Skip to content

Commit

Permalink
Add constraint methods to BaseMultiTableSynthesizer (#1178)
Browse files Browse the repository at this point in the history
* Address warnings issue

* Add tests to single table + address feedback

* Update err msg

* Fix typo
  • Loading branch information
fealho authored Jan 20, 2023
1 parent bf6cc6f commit 129d3ba
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 0 deletions.
47 changes: 47 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd

from sdv.errors import SynthesizerInputError
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError

Expand Down Expand Up @@ -359,3 +360,49 @@ def get_learned_distributions(self, table_name):
"""
synthesizer = self._table_synthesizers[table_name]
return synthesizer.get_learned_distributions()

def add_constraints(self, constraints):
    """Add constraints to the synthesizer.

    Each constraint is handed to the data processor of the table synthesizer
    registered under its ``table_name``. The dictionaries passed in are deep-copied,
    so the caller's objects are left untouched.

    Args:
        constraints (list):
            List of constraints described as dictionaries in the following format:
                * ``constraint_class``: Name of the constraint to apply.
                * ``table_name``: Name of the table where to apply the constraint.
                * ``constraint_parameters``: A dictionary with the constraint parameters.

    Raises:
        SynthesizerInputError:
            Raises when the ``Unique`` constraint is passed.
        KeyError:
            Raises when a constraint has no ``constraint_class`` or ``table_name`` entry,
            or when ``table_name`` has no entry in ``self._table_synthesizers``
            (assuming both are plain dictionaries).
    """
    for constraint in constraints:
        if constraint['constraint_class'] == 'Unique':
            raise SynthesizerInputError(
                "The constraint class 'Unique' is not currently supported for multi-table"
                ' synthesizers. Please remove the constraint for this synthesizer.'
            )

    # Resolve every target synthesizer before modifying any of them, so that a bad
    # ``table_name`` late in the list cannot leave earlier constraints half-applied.
    routed = []
    for constraint in constraints:
        constraint = deepcopy(constraint)
        synthesizer = self._table_synthesizers[constraint.pop('table_name')]
        routed.append((synthesizer, constraint))

    if self._fitted:
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(
            "For these constraints to take effect, please refit the synthesizer using 'fit'.",
            stacklevel=2,
        )

    for synthesizer, constraint in routed:
        synthesizer._data_processor.add_constraints([constraint])

def get_constraints(self):
    """Get constraints of the synthesizer.

    Collects the constraints reported by every table synthesizer and labels each
    one with the name of the table it came from.

    Returns:
        list:
            List of dictionaries describing the constraints of the synthesizer. Each
            dictionary holds the entries reported by the table synthesizer plus a
            ``table_name`` entry.
    """
    constraints = []
    for table_name, synthesizer in self._table_synthesizers.items():
        for constraint in synthesizer.get_constraints():
            # Tag a copy rather than the returned dictionary itself: if the table
            # synthesizer hands back its own stored objects, writing ``table_name``
            # into them would leak the key into its single-table constraints.
            constraints.append({**constraint, 'table_name': table_name})

    return constraints
109 changes: 109 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from sdv.errors import SynthesizerInputError
from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError
Expand Down Expand Up @@ -659,3 +660,111 @@ def test_get_learned_distributions_raises_an_error(self):
)
with pytest.raises(ValueError, match=error_msg):
instance.get_learned_distributions('nesreca')

def test_add_constraint_warning(self):
    """Check the refit warning is emitted when the synthesizer is flagged as fitted."""
    # Setup
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())
    synthesizer._fitted = True
    expected_warning = (
        "For these constraints to take effect, please refit the synthesizer using 'fit'."
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        # An empty list is enough: the warning only depends on the fitted flag.
        synthesizer.add_constraints([])

def test_add_constraints(self):
    """Check each constraint ends up on the synthesizer of the table it names."""
    # Setup
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())
    nesreca_input = {
        'constraint_class': 'Positive',
        'table_name': 'nesreca',
        'constraint_parameters': {'column_name': 'id_nesreca', 'strict_boundaries': True},
    }
    oseba_input = {
        'constraint_class': 'Negative',
        'table_name': 'oseba',
        'constraint_parameters': {'column_name': 'id_nesreca', 'strict_boundaries': False},
    }

    # Run
    synthesizer.add_constraints([nesreca_input, oseba_input])

    # Assert
    # The per-table synthesizers are expected to hold the constraints without
    # the ``table_name`` entry.
    nesreca_expected = {
        'constraint_class': 'Positive',
        'constraint_parameters': {'column_name': 'id_nesreca', 'strict_boundaries': True},
    }
    oseba_expected = {
        'constraint_class': 'Negative',
        'constraint_parameters': {'column_name': 'id_nesreca', 'strict_boundaries': False},
    }
    table_synthesizers = synthesizer._table_synthesizers
    assert table_synthesizers['nesreca'].get_constraints() == [nesreca_expected]
    assert table_synthesizers['oseba'].get_constraints() == [oseba_expected]

def test_add_constraints_unique(self):
    """Check that passing a ``Unique`` constraint raises ``SynthesizerInputError``."""
    # Setup
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())
    rejected = [
        {
            'constraint_class': 'Unique',
            'table_name': 'oseba',
            'constraint_parameters': {'column_name': 'id_nesreca'},
        }
    ]
    # Escaped because ``match`` treats the message as a regular expression.
    expected_message = re.escape(
        "The constraint class 'Unique' is not currently supported for multi-table"
        ' synthesizers. Please remove the constraint for this synthesizer.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        synthesizer.add_constraints(rejected)

def test_get_constraints(self):
    """Check the added constraints come back with their ``table_name`` entry."""
    # Setup
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())
    added = [
        {
            'constraint_class': 'Positive',
            'table_name': 'nesreca',
            'constraint_parameters': {'column_name': 'id_nesreca', 'strict_boundaries': True},
        },
        {
            'constraint_class': 'Negative',
            'table_name': 'oseba',
            'constraint_parameters': {'column_name': 'id_nesreca', 'strict_boundaries': False},
        },
    ]
    synthesizer.add_constraints(added)

    # Run
    returned = synthesizer.get_constraints()

    # Assert
    assert returned == added
84 changes: 84 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,3 +1785,87 @@ def test_load(self):
assert instance.metadata._sequence_key is None
assert instance.metadata._sequence_index is None
assert instance.metadata._version == 'SINGLE_TABLE_V1'

def test_add_constraint_warning(self):
    """Check the refit warning is emitted when the synthesizer is flagged as fitted."""
    # Setup
    synthesizer = BaseSingleTableSynthesizer(SingleTableMetadata())
    synthesizer._fitted = True
    expected_warning = (
        "For these constraints to take effect, please refit the synthesizer using 'fit'."
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        synthesizer.add_constraints([])

def test_add_constraints(self):
    """Check added constraints are reported back by ``get_constraints`` in order."""
    # Setup
    table_metadata = SingleTableMetadata()
    table_metadata.add_column('col', sdtype='numerical')
    synthesizer = BaseSingleTableSynthesizer(table_metadata)
    positive_input = {
        'constraint_class': 'Positive',
        'constraint_parameters': {'column_name': 'col', 'strict_boundaries': True},
    }
    negative_input = {
        'constraint_class': 'Negative',
        'constraint_parameters': {'column_name': 'col', 'strict_boundaries': False},
    }

    # Run
    synthesizer.add_constraints([positive_input, negative_input])

    # Assert
    # Expected values are spelled out again instead of reusing the input objects.
    expected = [
        {
            'constraint_class': 'Positive',
            'constraint_parameters': {'column_name': 'col', 'strict_boundaries': True},
        },
        {
            'constraint_class': 'Negative',
            'constraint_parameters': {'column_name': 'col', 'strict_boundaries': False},
        },
    ]
    assert synthesizer.get_constraints() == expected

def test_get_constraints(self):
    """Check ``get_constraints`` returns what was passed to ``add_constraints``."""
    # Setup
    table_metadata = SingleTableMetadata()
    table_metadata.add_column('col', sdtype='numerical')
    synthesizer = BaseSingleTableSynthesizer(table_metadata)
    added = [
        {
            'constraint_class': 'Positive',
            'constraint_parameters': {'column_name': 'col', 'strict_boundaries': True},
        },
        {
            'constraint_class': 'Negative',
            'constraint_parameters': {'column_name': 'col', 'strict_boundaries': False},
        },
    ]
    synthesizer.add_constraints(added)

    # Run
    returned = synthesizer.get_constraints()

    # Assert
    assert returned == added

0 comments on commit 129d3ba

Please sign in to comment.