From 129d3ba851d728bff081a39d619083e4d7603491 Mon Sep 17 00:00:00 2001
From: Felipe Alex Hofmann
Date: Fri, 20 Jan 2023 13:32:47 -0800
Subject: [PATCH] Add constraint methods to `BaseMultiTableSynthesizer` (#1178)

* Address warnings issue

* Add tests to single table + address feedback

* Update err msg

* Fix typo

* Resolve every target table before registering any constraint

* Return copies from `get_constraints`
---
 sdv/multi_table/base.py              |  63 +++++++++++++
 tests/unit/multi_table/test_base.py  | 137 +++++++++++++++++++++++++++
 tests/unit/single_table/test_base.py |  84 ++++++++++++++++
 3 files changed, 284 insertions(+)

diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py
index 35e597e6c..4331d6ebc 100644
--- a/sdv/multi_table/base.py
+++ b/sdv/multi_table/base.py
@@ -7,6 +7,7 @@
 import numpy as np
 import pandas as pd
 
+from sdv.errors import SynthesizerInputError
 from sdv.single_table.copulas import GaussianCopulaSynthesizer
 from sdv.single_table.errors import InvalidDataError
 
@@ -359,3 +360,65 @@ def get_learned_distributions(self, table_name):
         """
         synthesizer = self._table_synthesizers[table_name]
         return synthesizer.get_learned_distributions()
+
+    def add_constraints(self, constraints):
+        """Add constraints to the synthesizer.
+
+        The target table of every constraint is resolved before any constraint is
+        registered, so an invalid entry never leaves the list partially applied.
+
+        Args:
+            constraints (list):
+                List of constraints described as dictionaries in the following format:
+                    * ``constraint_class``: Name of the constraint to apply.
+                    * ``table_name``: Name of the table where to apply the constraint.
+                    * ``constraint_parameters``: A dictionary with the constraint parameters.
+
+        Raises:
+            SynthesizerInputError:
+                Raises when the ``Unique`` constraint is passed.
+            KeyError:
+                Raises when a constraint lacks ``constraint_class`` or ``table_name``, or
+                when ``table_name`` is not a table of this synthesizer.
+        """
+        for constraint in constraints:
+            if constraint['constraint_class'] == 'Unique':
+                raise SynthesizerInputError(
+                    "The constraint class 'Unique' is not currently supported for multi-table"
+                    ' synthesizers. Please remove the constraint for this synthesizer.'
+                )
+
+        # Resolve every target table first so that an unknown or missing ``table_name``
+        # fails before anything has been registered.
+        resolved = []
+        for constraint in constraints:
+            # Copy so the caller's dictionary keeps its ``table_name`` entry.
+            constraint = deepcopy(constraint)
+            synthesizer = self._table_synthesizers[constraint.pop('table_name')]
+            resolved.append((synthesizer, constraint))
+
+        if self._fitted:
+            warnings.warn(
+                "For these constraints to take effect, please refit the synthesizer using 'fit'."
+            )
+
+        for synthesizer, constraint in resolved:
+            synthesizer._data_processor.add_constraints([constraint])
+
+    def get_constraints(self):
+        """Get constraints of the synthesizer.
+
+        Returns:
+            list:
+                List of dictionaries describing the constraints of the synthesizer. Each one is
+                a copy that also carries the ``table_name`` of the table it applies to.
+        """
+        constraints = []
+        for table_name, synthesizer in self._table_synthesizers.items():
+            for constraint in synthesizer.get_constraints():
+                # Copy before tagging so the table synthesizer's own data is never modified.
+                constraint = deepcopy(constraint)
+                constraint['table_name'] = table_name
+                constraints.append(constraint)
+
+        return constraints
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py
index a85f098c3..2a6c05824 100644
--- a/tests/unit/multi_table/test_base.py
+++ b/tests/unit/multi_table/test_base.py
@@ -6,6 +6,7 @@
 import pandas as pd
 import pytest
 
+from sdv.errors import SynthesizerInputError
 from sdv.multi_table.base import BaseMultiTableSynthesizer
 from sdv.single_table.copulas import GaussianCopulaSynthesizer
 from sdv.single_table.errors import InvalidDataError
@@ -659,3 +660,139 @@ def test_get_learned_distributions_raises_an_error(self):
         )
         with pytest.raises(ValueError, match=error_msg):
             instance.get_learned_distributions('nesreca')
+
+    def test_add_constraint_warning(self):
+        """Test a warning is raised when the synthesizer had already been fitted."""
+        # Setup
+        metadata = get_multi_table_metadata()
+        instance = BaseMultiTableSynthesizer(metadata)
+        instance._fitted = True
+
+        # Run and Assert
+        warn_msg = (
+            "For these constraints to take effect, please refit the synthesizer using 'fit'."
+        )
+        with pytest.warns(UserWarning, match=warn_msg):
+            instance.add_constraints([])
+
+    def test_add_constraints(self):
+        """Test a list of constraints can be added to the synthesizer."""
+        # Setup
+        metadata = get_multi_table_metadata()
+        instance = BaseMultiTableSynthesizer(metadata)
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'table_name': 'nesreca',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': True
+            }
+        }
+        negative_constraint = {
+            'constraint_class': 'Negative',
+            'table_name': 'oseba',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': False
+            }
+        }
+
+        # Run
+        instance.add_constraints([positive_constraint, negative_constraint])
+
+        # Assert
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': True
+            }
+        }
+        negative_constraint = {
+            'constraint_class': 'Negative',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': False
+            }
+        }
+        assert instance._table_synthesizers['nesreca'].get_constraints() == [positive_constraint]
+        assert instance._table_synthesizers['oseba'].get_constraints() == [negative_constraint]
+
+    def test_add_constraints_unique(self):
+        """Test an error is raised when a ``Unique`` constraint is passed."""
+        # Setup
+        metadata = get_multi_table_metadata()
+        instance = BaseMultiTableSynthesizer(metadata)
+        unique_constraint = {
+            'constraint_class': 'Unique',
+            'table_name': 'oseba',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+            }
+        }
+
+        # Run and Assert
+        err_msg = re.escape(
+            "The constraint class 'Unique' is not currently supported for multi-table"
+            ' synthesizers. Please remove the constraint for this synthesizer.'
+        )
+        with pytest.raises(SynthesizerInputError, match=err_msg):
+            instance.add_constraints([unique_constraint])
+
+    def test_get_constraints(self):
+        """Test a list of constraints is returned by the method."""
+        # Setup
+        metadata = get_multi_table_metadata()
+        instance = BaseMultiTableSynthesizer(metadata)
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'table_name': 'nesreca',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': True
+            }
+        }
+        negative_constraint = {
+            'constraint_class': 'Negative',
+            'table_name': 'oseba',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': False
+            }
+        }
+        constraints = [positive_constraint, negative_constraint]
+        instance.add_constraints(constraints)
+
+        # Run
+        output = instance.get_constraints()
+
+        # Assert
+        assert output == constraints
+
+    def test_add_constraints_unknown_table(self):
+        """Test no constraint is added when one of them targets an unknown table."""
+        # Setup
+        metadata = get_multi_table_metadata()
+        instance = BaseMultiTableSynthesizer(metadata)
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'table_name': 'nesreca',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': True
+            }
+        }
+        unknown_table_constraint = {
+            'constraint_class': 'Negative',
+            'table_name': 'unknown',
+            'constraint_parameters': {
+                'column_name': 'id_nesreca',
+                'strict_boundaries': False
+            }
+        }
+
+        # Run and Assert
+        with pytest.raises(KeyError):
+            instance.add_constraints([positive_constraint, unknown_table_constraint])
+
+        assert instance.get_constraints() == []
diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py
index 7e726e56f..3a8410434 100644
--- a/tests/unit/single_table/test_base.py
+++ b/tests/unit/single_table/test_base.py
@@ -1785,3 +1785,87 @@ def test_load(self):
         assert instance.metadata._sequence_key is None
         assert instance.metadata._sequence_index is None
         assert instance.metadata._version == 'SINGLE_TABLE_V1'
+
+    def test_add_constraint_warning(self):
+        """Test a warning is raised when the synthesizer had already been fitted."""
+        # Setup
+        metadata = SingleTableMetadata()
+        instance = BaseSingleTableSynthesizer(metadata)
+        instance._fitted = True
+
+        # Run and Assert
+        warn_msg = (
+            "For these constraints to take effect, please refit the synthesizer using 'fit'."
+        )
+        with pytest.warns(UserWarning, match=warn_msg):
+            instance.add_constraints([])
+
+    def test_add_constraints(self):
+        """Test a list of constraints can be added to the synthesizer."""
+        # Setup
+        metadata = SingleTableMetadata()
+        metadata.add_column('col', sdtype='numerical')
+        instance = BaseSingleTableSynthesizer(metadata)
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'constraint_parameters': {
+                'column_name': 'col',
+                'strict_boundaries': True
+            }
+        }
+        negative_constraint = {
+            'constraint_class': 'Negative',
+            'constraint_parameters': {
+                'column_name': 'col',
+                'strict_boundaries': False
+            }
+        }
+
+        # Run
+        instance.add_constraints([positive_constraint, negative_constraint])
+
+        # Assert
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'constraint_parameters': {
+                'column_name': 'col',
+                'strict_boundaries': True
+            }
+        }
+        negative_constraint = {
+            'constraint_class': 'Negative',
+            'constraint_parameters': {
+                'column_name': 'col',
+                'strict_boundaries': False
+            }
+        }
+        assert instance.get_constraints() == [positive_constraint, negative_constraint]
+
+    def test_get_constraints(self):
+        """Test a list of constraints is returned by the method."""
+        # Setup
+        metadata = SingleTableMetadata()
+        metadata.add_column('col', sdtype='numerical')
+        instance = BaseSingleTableSynthesizer(metadata)
+        positive_constraint = {
+            'constraint_class': 'Positive',
+            'constraint_parameters': {
+                'column_name': 'col',
+                'strict_boundaries': True
+            }
+        }
+        negative_constraint = {
+            'constraint_class': 'Negative',
+            'constraint_parameters': {
+                'column_name': 'col',
+                'strict_boundaries': False
+            }
+        }
+        constraints = [positive_constraint, negative_constraint]
+        instance.add_constraints(constraints)
+
+        # Run
+        output = instance.get_constraints()
+
+        # Assert
+        assert output == constraints