Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add constraint methods to BaseMultiTableSynthesizer #1178

Merged
merged 4 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd

from sdv.errors import SynthesizerInputError
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError

Expand Down Expand Up @@ -359,3 +360,49 @@ def get_learned_distributions(self, table_name):
"""
synthesizer = self._table_synthesizers[table_name]
return synthesizer.get_learned_distributions()

def add_constraints(self, constraints):
    """Add constraints to the synthesizer.

    The whole list is checked for unsupported classes and every constraint is
    matched to its table synthesizer before any of them is registered, so an
    unsupported class or an unknown table name is reported before any table
    synthesizer is modified. The dictionaries passed in are not mutated.

    Args:
        constraints (list):
            List of constraints described as dictionaries in the following format:
                * ``constraint_class``: Name of the constraint to apply.
                * ``table_name``: Name of the table where to apply the constraint.
                * ``constraint_parameters``: A dictionary with the constraint parameters.

    Raises:
        SynthesizerInputError:
            Raises when the ``Unique`` constraint is passed.
        KeyError:
            Raises when a constraint lacks ``constraint_class`` or ``table_name``, or
            when ``table_name`` has no entry in the table synthesizers.
    """
    if any(constraint['constraint_class'] == 'Unique' for constraint in constraints):
        raise SynthesizerInputError(
            "The constraint class 'Unique' is not currently supported for multi-table"
            ' synthesizers. Please remove the constraint for this synthesizer.'
        )

    if self._fitted:
        # Warn here, once per call, instead of relying on the single-table
        # ``add_constraints`` warning, which would show for every constraint in the list.
        warnings.warn(
            "For these constraints to take effect, please refit the synthesizer using 'fit'."
        )

    # Resolve every target synthesizer first: a bad ``table_name`` then fails before
    # anything has been registered instead of leaving the list half applied.
    resolved = []
    for constraint in constraints:
        # Work on a copy so that popping ``table_name`` does not alter the caller's dict.
        constraint = deepcopy(constraint)
        synthesizer = self._table_synthesizers[constraint.pop('table_name')]
        resolved.append((synthesizer, constraint))

    for synthesizer, constraint in resolved:
        synthesizer._data_processor.add_constraints([constraint])

def get_constraints(self):
    """Get constraints of the synthesizer.

    Collects the constraints reported by every table synthesizer and tags each one
    with the name of the table it belongs to.

    Returns:
        list:
            List of dictionaries describing the constraints of the synthesizer. Each
            dictionary holds the keys reported by the table synthesizer plus
            ``table_name``.
    """
    constraints = []
    for table_name, synthesizer in self._table_synthesizers.items():
        for constraint in synthesizer.get_constraints():
            # Build a new dict rather than writing ``table_name`` into the one handed
            # back by the table synthesizer, so its own records are never altered.
            constraints.append({**constraint, 'table_name': table_name})

    return constraints
109 changes: 109 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from sdv.errors import SynthesizerInputError
from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError
Expand Down Expand Up @@ -659,3 +660,111 @@ def test_get_learned_distributions_raises_an_error(self):
)
with pytest.raises(ValueError, match=error_msg):
instance.get_learned_distributions('nesreca')

def test_add_constraint_warning(self):
    """Test a warning is raised when the synthesizer had already been fitted."""
    # Setup
    metadata = get_multi_table_metadata()
    instance = BaseMultiTableSynthesizer(metadata)
    instance._fitted = True

    # Run and Assert
    # ``match`` is applied as a regular expression, so escape the message to compare
    # it literally (the trailing '.' would otherwise match any character).
    warn_msg = re.escape(
        "For these constraints to take effect, please refit the synthesizer using 'fit'."
    )
    with pytest.warns(UserWarning, match=warn_msg):
        instance.add_constraints([])

def test_add_constraints(self):
    """Test that each constraint ends up on the synthesizer of its own table."""
    # Setup
    instance = BaseMultiTableSynthesizer(get_multi_table_metadata())
    constraints = [
        {
            'constraint_class': 'Positive',
            'table_name': 'nesreca',
            'constraint_parameters': {
                'column_name': 'id_nesreca',
                'strict_boundaries': True
            }
        },
        {
            'constraint_class': 'Negative',
            'table_name': 'oseba',
            'constraint_parameters': {
                'column_name': 'id_nesreca',
                'strict_boundaries': False
            }
        },
    ]
    # The table synthesizers are expected to report the constraints without ``table_name``.
    expected_nesreca = [{
        'constraint_class': 'Positive',
        'constraint_parameters': {
            'column_name': 'id_nesreca',
            'strict_boundaries': True
        }
    }]
    expected_oseba = [{
        'constraint_class': 'Negative',
        'constraint_parameters': {
            'column_name': 'id_nesreca',
            'strict_boundaries': False
        }
    }]

    # Run
    instance.add_constraints(constraints)

    # Assert
    table_synthesizers = instance._table_synthesizers
    assert table_synthesizers['nesreca'].get_constraints() == expected_nesreca
    assert table_synthesizers['oseba'].get_constraints() == expected_oseba

def test_add_constraints_unique(self):
    """Test that passing a ``Unique`` constraint raises a ``SynthesizerInputError``."""
    # Setup
    instance = BaseMultiTableSynthesizer(get_multi_table_metadata())
    constraints = [{
        'constraint_class': 'Unique',
        'table_name': 'oseba',
        'constraint_parameters': {
            'column_name': 'id_nesreca',
        }
    }]
    expected_message = re.escape(
        "The constraint class 'Unique' is not currently supported for multi-table"
        ' synthesizers. Please remove the constraint for this synthesizer.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        instance.add_constraints(constraints)

def test_get_constraints(self):
    """Test that the added constraints are returned together with their table names."""
    # Setup
    instance = BaseMultiTableSynthesizer(get_multi_table_metadata())
    expected = [
        {
            'constraint_class': 'Positive',
            'table_name': 'nesreca',
            'constraint_parameters': {
                'column_name': 'id_nesreca',
                'strict_boundaries': True
            }
        },
        {
            'constraint_class': 'Negative',
            'table_name': 'oseba',
            'constraint_parameters': {
                'column_name': 'id_nesreca',
                'strict_boundaries': False
            }
        },
    ]
    instance.add_constraints(expected)

    # Run
    returned = instance.get_constraints()

    # Assert
    assert returned == expected
84 changes: 84 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,3 +1785,87 @@ def test_load(self):
assert instance.metadata._sequence_key is None
assert instance.metadata._sequence_index is None
assert instance.metadata._version == 'SINGLE_TABLE_V1'

def test_add_constraint_warning(self):
    """Test that adding constraints to an already fitted synthesizer emits a warning."""
    # Setup
    instance = BaseSingleTableSynthesizer(SingleTableMetadata())
    instance._fitted = True
    expected_warning = (
        "For these constraints to take effect, please refit the synthesizer using 'fit'."
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        instance.add_constraints([])

def test_add_constraints(self):
    """Test that the added constraints are reported back by ``get_constraints``."""
    # Setup
    metadata = SingleTableMetadata()
    metadata.add_column('col', sdtype='numerical')
    instance = BaseSingleTableSynthesizer(metadata)
    constraints = [
        {
            'constraint_class': 'Positive',
            'constraint_parameters': {
                'column_name': 'col',
                'strict_boundaries': True
            }
        },
        {
            'constraint_class': 'Negative',
            'constraint_parameters': {
                'column_name': 'col',
                'strict_boundaries': False
            }
        },
    ]
    # Separate literals, so the check does not depend on the input dicts staying untouched.
    expected = [
        {
            'constraint_class': 'Positive',
            'constraint_parameters': {
                'column_name': 'col',
                'strict_boundaries': True
            }
        },
        {
            'constraint_class': 'Negative',
            'constraint_parameters': {
                'column_name': 'col',
                'strict_boundaries': False
            }
        },
    ]

    # Run
    instance.add_constraints(constraints)

    # Assert
    assert instance.get_constraints() == expected

def test_get_constraints(self):
    """Test that ``get_constraints`` returns the constraints that were added."""
    # Setup
    metadata = SingleTableMetadata()
    metadata.add_column('col', sdtype='numerical')
    instance = BaseSingleTableSynthesizer(metadata)
    expected = [
        {
            'constraint_class': 'Positive',
            'constraint_parameters': {
                'column_name': 'col',
                'strict_boundaries': True
            }
        },
        {
            'constraint_class': 'Negative',
            'constraint_parameters': {
                'column_name': 'col',
                'strict_boundaries': False
            }
        },
    ]
    instance.add_constraints(expected)

    # Run
    returned = instance.get_constraints()

    # Assert
    assert returned == expected