From 95233e7006d224b8dbd0a05c94dce7564dc590aa Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 19 Jan 2023 11:37:31 -0800 Subject: [PATCH 1/4] Address warnings issue --- sdv/multi_table/base.py | 46 ++++++++++++++++ tests/unit/multi_table/test_base.py | 84 +++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 35e597e6c..93b2bc3ec 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd +from sdv.errors import SynthesizerInputError from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.errors import InvalidDataError @@ -359,3 +360,48 @@ def get_learned_distributions(self, table_name): """ synthesizer = self._table_synthesizers[table_name] return synthesizer.get_learned_distributions() + + def add_constraints(self, constraints): + """Add constraints to the synthesizer. + + Args: + constraints (list): + List of constraints described as dictionaries in the following format: + * ``constraint_class``: Name of the constraint to apply. + * ``table_name``: Name of the table where to apply the constraint. + * ``constraint_parameters``: A dictionary with the constraint parameters. + + Raises: + SynthesizerInputError: + Raises when the ``Unique`` constraint is passed. + """ + if self._fitted: + warnings.warn( + "For these constraints to take effect, please refit the synthesizer using 'fit'." + ) + + for constraint in constraints: + if constraint['constraint_class'] == 'Unique': + raise SynthesizerInputError( + "The constraint class 'Unique' is not currently supported." + 'Please remove the constraint for this synthesizer.' + ) + + constraint = deepcopy(constraint) + synthesizer = self._table_synthesizers[constraint.pop('table_name')] + synthesizer._data_processor.add_constraints([constraint]) + + def get_constraints(self): + """Get constraints of the synthesizer. + + Returns: + list: + List of dictionaries describing the constraints of the synthesizer. + """ + constraints = [] + for table_name, synthesizer in self._table_synthesizers.items(): + for constraint in synthesizer.get_constraints(): + constraint['table_name'] = table_name + constraints.append(constraint) + + return constraints diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index a85f098c3..28ee4402e 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -6,6 +6,7 @@ import pandas as pd import pytest +from sdv.errors import SynthesizerInputError from sdv.multi_table.base import BaseMultiTableSynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.errors import InvalidDataError @@ -659,3 +660,86 @@ def test_get_learned_distributions_raises_an_error(self): ) with pytest.raises(ValueError, match=error_msg): instance.get_learned_distributions('nesreca') + + def test_add_constraints(self): + """Test a list of constraits can be added to the synthesizer.""" + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + positive_constraint = { + 'constraint_class': 'Positive', + 'table_name': 'nesreca', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': True + } + } + negative_constraint = { + 'constraint_class': 'Negative', + 'table_name': 'oseba', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': False + } + } + + # Run + instance.add_constraints([positive_constraint, negative_constraint]) + + # Assert + positive_constraint.pop('table_name') + negative_constraint.pop('table_name') + assert instance._table_synthesizers['nesreca'].get_constraints() == [positive_constraint] + assert instance._table_synthesizers['oseba'].get_constraints() == [negative_constraint] + + def test_add_constraints_unique(self): + """Test an error is raised when a ``Unique`` constraint is passed.""" + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + unique_constraint = { + 'constraint_class': 'Unique', + 'table_name': 'oseba', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': False + } + } + + # Run and Assert + err_msg = re.escape( + "The constraint class 'Unique' is not currently supported." + 'Please remove the constraint for this synthesizer.' + ) + with pytest.raises(SynthesizerInputError, match=err_msg): + instance.add_constraints([unique_constraint]) + + def test_get_constraints(self): + """Test a list of constraits is returned by the method.""" + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + positive_constraint = { + 'constraint_class': 'Positive', + 'table_name': 'nesreca', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': True + } + } + negative_constraint = { + 'constraint_class': 'Negative', + 'table_name': 'oseba', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': False + } + } + constraints = [positive_constraint, negative_constraint] + + # Run + instance.add_constraints(constraints) + output = instance.get_constraints() + + # Assert + assert output == constraints From 002ae89a2e3f0ed85a5386dbdd702fa9f8744764 Mon Sep 17 00:00:00 2001 From: Felipe Date: Fri, 20 Jan 2023 09:49:37 -0800 Subject: [PATCH 2/4] Add tests to single table + address feedback --- sdv/multi_table/base.py | 13 +++-- tests/unit/multi_table/test_base.py | 35 ++++++++++-- tests/unit/single_table/test_base.py | 84 ++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 11 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 93b2bc3ec..4a52eb9cb 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -375,18 +375,19 @@ def add_constraints(self, constraints): SynthesizerInputError: Raises when the ``Unique`` constraint is passed. """ - if self._fitted: - warnings.warn( - "For these constraints to take effect, please refit the synthesizer using 'fit'." - ) - for constraint in constraints: if constraint['constraint_class'] == 'Unique': raise SynthesizerInputError( "The constraint class 'Unique' is not currently supported." - 'Please remove the constraint for this synthesizer.' + ' Please remove the constraint for this synthesizer.' ) + if self._fitted: + warnings.warn( + "For these constraints to take effect, please refit the synthesizer using 'fit'." + ) + + for constraint in constraints: constraint = deepcopy(constraint) synthesizer = self._table_synthesizers[constraint.pop('table_name')] synthesizer._data_processor.add_constraints([constraint]) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 28ee4402e..419e86cfd 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -661,6 +661,20 @@ def test_get_learned_distributions_raises_an_error(self): with pytest.raises(ValueError, match=error_msg): instance.get_learned_distributions('nesreca') + def test_add_constraint_warning(self): + """Test a warning is raised when the synthesizer had already been fitted.""" + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + instance._fitted = True + + # Run and Assert + warn_msg = ( + "For these constraints to take effect, please refit the synthesizer using 'fit'." + ) + with pytest.warns(UserWarning, match=warn_msg): + instance.add_constraints([]) + def test_add_constraints(self): """Test a list of constraits can be added to the synthesizer.""" # Setup @@ -687,8 +701,20 @@ def test_add_constraints(self): instance.add_constraints([positive_constraint, negative_constraint]) # Assert - positive_constraint.pop('table_name') - negative_constraint.pop('table_name') + positive_constraint = { + 'constraint_class': 'Positive', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': True + } + } + negative_constraint = { + 'constraint_class': 'Negative', + 'constraint_parameters': { + 'column_name': 'id_nesreca', + 'strict_boundaries': False + } + } assert instance._table_synthesizers['nesreca'].get_constraints() == [positive_constraint] assert instance._table_synthesizers['oseba'].get_constraints() == [negative_constraint] @@ -702,14 +728,13 @@ def test_add_constraints_unique(self): 'table_name': 'oseba', 'constraint_parameters': { 'column_name': 'id_nesreca', - 'strict_boundaries': False } } # Run and Assert err_msg = re.escape( "The constraint class 'Unique' is not currently supported." - 'Please remove the constraint for this synthesizer.' + ' Please remove the constraint for this synthesizer.' ) with pytest.raises(SynthesizerInputError, match=err_msg): instance.add_constraints([unique_constraint]) @@ -736,9 +761,9 @@ def test_get_constraints(self): } } constraints = [positive_constraint, negative_constraint] + instance.add_constraints(constraints) # Run - instance.add_constraints(constraints) output = instance.get_constraints() # Assert diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 7e726e56f..36fd74280 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1785,3 +1785,87 @@ def test_load(self): assert instance.metadata._sequence_key is None assert instance.metadata._sequence_index is None assert instance.metadata._version == 'SINGLE_TABLE_V1' + + def test_add_constraint_warning(self): + """Test a warning is raised when the synthesizer had already been fitted.""" + # Setup + metadata = SingleTableMetadata() + instance = BaseSingleTableSynthesizer(metadata) + instance._fitted = True + + # Run and Assert + warn_msg = ( + "For these constraints to take effect, please refit the synthesizer using 'fit'." + ) + with pytest.warns(UserWarning, match=warn_msg): + instance.add_constraints([]) + + def test_add_constraints(self): + """Test a list of constraits can be added to the synthesizer.""" + # Setup + metadata = SingleTableMetadata() + metadata.add_column('col', sdtype='numerical') + instance = BaseSingleTableSynthesizer(metadata) + positive_constraint = { + 'constraint_class': 'Positive', + 'constraint_parameters': { + 'column_name': 'col', + 'strict_boundaries': True + } + } + negative_constraint = { + 'constraint_class': 'Negative', + 'constraint_parameters': { + 'column_name': 'col', + 'strict_boundaries': False + } + } + + # Run + instance.add_constraints([positive_constraint, negative_constraint]) + + # Assert + positive_constraint = { + 'constraint_class': 'Positive', + 'constraint_parameters': { + 'column_name': 'col', + 'strict_boundaries': True + } + } + negative_constraint = { + 'constraint_class': 'Negative', + 'constraint_parameters': { + 'column_name': 'col', + 'strict_boundaries': False + } + } + assert instance.get_constraints() == [positive_constraint, negative_constraint] + + def test_get_constraints(self): + """Test a list of constraits is returned by the method.""" + # Setup + metadata = SingleTableMetadata() + metadata.add_column('col', sdtype='numerical') + instance = BaseSingleTableSynthesizer(metadata) + positive_constraint = { + 'constraint_class': 'Positive', + 'constraint_parameters': { + 'column_name': 'col', + 'strict_boundaries': True + } + } + negative_constraint = { + 'constraint_class': 'Negative', + 'constraint_parameters': { + 'column_name': 'col', + 'strict_boundaries': False + } + } + constraints = [positive_constraint, negative_constraint] + instance.add_constraints(constraints) + + # Run + output = instance.get_constraints() + + # Assert + assert output == constraints From e09bf98ea0465555c6d20661157a7b6a7e159db1 Mon Sep 17 00:00:00 2001 From: Felipe Date: Fri, 20 Jan 2023 10:13:02 -0800 Subject: [PATCH 3/4] Update err msg --- sdv/multi_table/base.py | 4 ++-- tests/unit/multi_table/test_base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 4a52eb9cb..4331d6ebc 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -378,8 +378,8 @@ def add_constraints(self, constraints): for constraint in constraints: if constraint['constraint_class'] == 'Unique': raise SynthesizerInputError( - "The constraint class 'Unique' is not currently supported." - ' Please remove the constraint for this synthesizer.' + "The constraint class 'Unique' is not currently supported for multi-table" + ' synthesizers. Please remove the constraint for this synthesizer.' ) if self._fitted: diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 419e86cfd..9eeca65d2 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -733,8 +733,8 @@ def test_add_constraints_unique(self): # Run and Assert err_msg = re.escape( - "The constraint class 'Unique' is not currently supported." - ' Please remove the constraint for this synthesizer.' + "The constraint class 'Unique' is not currently supported for multi-table" + ' synthesizers. Please remove the constraint for this synthesizer.' ) with pytest.raises(SynthesizerInputError, match=err_msg): instance.add_constraints([unique_constraint]) From 88865a5437d91bddf7feb94e261b01d6c3d39787 Mon Sep 17 00:00:00 2001 From: Felipe Date: Fri, 20 Jan 2023 12:29:47 -0800 Subject: [PATCH 4/4] Fix typo --- tests/unit/multi_table/test_base.py | 4 ++-- tests/unit/single_table/test_base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 9eeca65d2..2a6c05824 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -676,7 +676,7 @@ def test_add_constraint_warning(self): instance.add_constraints([]) def test_add_constraints(self): - """Test a list of constraits can be added to the synthesizer.""" + """Test a list of constraints can be added to the synthesizer.""" # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) @@ -740,7 +740,7 @@ def test_add_constraints_unique(self): instance.add_constraints([unique_constraint]) def test_get_constraints(self): - """Test a list of constraits is returned by the method.""" + """Test a list of constraints is returned by the method.""" # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 36fd74280..3a8410434 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1801,7 +1801,7 @@ def test_add_constraint_warning(self): instance.add_constraints([]) def test_add_constraints(self): - """Test a list of constraits can be added to the synthesizer.""" + """Test a list of constraints can be added to the synthesizer.""" # Setup metadata = SingleTableMetadata() metadata.add_column('col', sdtype='numerical') @@ -1842,7 +1842,7 @@ def test_add_constraints(self): assert instance.get_constraints() == [positive_constraint, negative_constraint] def test_get_constraints(self): - """Test a list of constraits is returned by the method.""" + """Test a list of constraints is returned by the method.""" # Setup metadata = SingleTableMetadata() metadata.add_column('col', sdtype='numerical')