Raise constraint errors together + misc. #807

Merged (11 commits) on May 26, 2022
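In short, the headline change: when several constraints reject the training data, fitting now raises one aggregated MultipleConstraintsErrors that describes every failure, instead of stopping at the first one. A minimal sketch of the new behavior, condensed from the integration test at the bottom of this diff (assumes an environment with this branch of sdv installed):

import pandas as pd

from sdv.constraints import Negative, Positive
from sdv.constraints.errors import MultipleConstraintsErrors
from sdv.tabular import GaussianCopula

# Every row violates both constraints: 'c' should be positive, 'd' should be negative.
data = pd.DataFrame({'c': [-1, -2, -3], 'd': [1, 2, 3]})
model = GaussianCopula(constraints=[Positive('c'), Negative('d')])

try:
    model.fit(data)
except MultipleConstraintsErrors as error:
    # One combined message describing each failing constraint and its invalid rows.
    print(error)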
sdv/constraints/base.py (32 additions, 1 deletion)

@@ -12,6 +12,7 @@
from rdt import HyperTransformer

from sdv.constraints.errors import MissingConstraintColumnError
from sdv.errors import ConstraintsNotMetError

LOGGER = logging.getLogger(__name__)

@@ -120,13 +121,17 @@ class Constraint(metaclass=ConstraintMeta):
    def _identity(self, table_data):
        return table_data

    def _identity_with_validation(self, table_data):
        self._validate_data_on_constraint(table_data)
        return table_data

    def __init__(self, handling_strategy, fit_columns_model=False):
        self.fit_columns_model = fit_columns_model
        if handling_strategy == 'transform':
            self.filter_valid = self._identity
        elif handling_strategy == 'reject_sampling':
            self.rebuild_columns = ()
            self.transform = self._identity
            self.transform = self._identity_with_validation
            self.reverse_transform = self._identity
        elif handling_strategy != 'all':
            raise ValueError('Unknown handling strategy: {}'.format(handling_strategy))
@@ -220,6 +225,31 @@ def _sample_constraint_columns(self, table_data):
        sampled_data = pd.concat(all_sampled_rows, ignore_index=True)
        return sampled_data

    def _validate_data_on_constraint(self, table_data):
        """Make sure the given data is valid for the given constraints.

        Args:
            data (pandas.DataFrame):
                Table data.

        Raises:
            ConstraintsNotMetError:
                If the table data is not valid for the provided constraints.
        """
        if set(self.constraint_columns).issubset(table_data.columns.values):
            is_valid_data = self.is_valid(table_data)
            if not is_valid_data.all():
                constraint_data = table_data[list(self.constraint_columns)]
                invalid_rows = constraint_data[~is_valid_data]
                err_msg = (
                    f"Data is not valid for the '{self.__class__.__name__}' constraint:\n"
                    f'{invalid_rows[:5]}'
                )
                if len(invalid_rows) > 5:
                    err_msg += f'\n+{len(invalid_rows) - 5} more'

                raise ConstraintsNotMetError(err_msg)

    def _validate_constraint_columns(self, table_data):
        """Validate the columns in ``table_data``.

@@ -277,6 +307,7 @@ def transform(self, table_data):
            pandas.DataFrame:
                Input data unmodified.
        """
        self._validate_data_on_constraint(table_data)
        table_data = self._validate_constraint_columns(table_data)
        return self._transform(table_data)
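As a standalone illustration of the message formatting above (plain pandas only, not SDV's API; the 'Positive'-style check is a stand-in for constraint.is_valid()), at most five invalid rows are shown, followed by a count of the rest:

import pandas as pd

table_data = pd.DataFrame({'c': [-1, -1, -1, -1, -1, -1, -1]})
is_valid_data = table_data['c'] > 0          # stand-in for constraint.is_valid(table_data)
invalid_rows = table_data[['c']][~is_valid_data]

err_msg = (
    f"Data is not valid for the 'Positive' constraint:\n"
    f'{invalid_rows[:5]}'
)
if len(invalid_rows) > 5:
    err_msg += f'\n+{len(invalid_rows) - 5} more'

print(err_msg)
# Data is not valid for the 'Positive' constraint:
#    c
# 0 -1
# 1 -1
# 2 -1
# 3 -1
# 4 -1
# +2 more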

sdv/constraints/errors.py (4 additions, 0 deletions)

@@ -3,3 +3,7 @@

class MissingConstraintColumnError(Exception):
    """Error to use when constraint is provided a table with missing columns."""


class MultipleConstraintsErrors(Exception):
    """Error used to represent a list of constraint errors."""

Member: Should this be Error instead of Errors?

Member (Author): I think Errors is clearer, since it is a list of multiple errors.
sdv/constraints/tabular.py (3 additions, 0 deletions)

@@ -19,7 +19,9 @@
      on the other columns of the table.
    * Between: Ensure that the value in one column is always between the values
      of two other columns/scalars.
    * Rounding: Round a column based on the specified number of digits.
    * OneHotEncoding: Ensure the rows of the specified columns are one hot encoded.
    * Unique: Ensure that each value for a specified column/group of columns is unique.
"""

Member (Author), on the constraint descriptions added above: Unrelated to my changes, we just forgot to add these here when creating the constraints.

import operator
@@ -1135,6 +1137,7 @@ class Unique(Constraint):

    def __init__(self, columns):
        self.columns = columns if isinstance(columns, list) else [columns]
        self.constraint_columns = tuple(self.columns)
        super().__init__(handling_strategy='reject_sampling', fit_columns_model=False)

    def is_valid(self, table_data):
sdv/metadata/table.py (8 additions, 28 deletions)

@@ -10,8 +10,7 @@
from faker import Faker

from sdv.constraints.base import Constraint
from sdv.constraints.errors import MissingConstraintColumnError
from sdv.errors import ConstraintsNotMetError
from sdv.constraints.errors import MissingConstraintColumnError, MultipleConstraintsErrors
from sdv.metadata.errors import MetadataError, MetadataNotFittedError
from sdv.metadata.utils import strings_from_regex

@@ -443,9 +442,15 @@ def _get_transformers(self, dtypes):
        return transformers

    def _fit_transform_constraints(self, data):
        errors = []
        for constraint in self._constraints:
            data = constraint.fit_transform(data)
            try:
                data = constraint.fit_transform(data)
            except Exception as e:
                errors.append(e)

        if errors:
            raise MultipleConstraintsErrors('\n' + '\n\n'.join(map(str, errors)))
        return data
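The loop above follows a plain collect-then-raise pattern; a minimal standalone sketch of the same idea (generic Python, no SDV imports; run_all_or_report is a hypothetical helper, not part of the library):

def run_all_or_report(steps):
    """Run every step, then raise a single error that lists all failures."""
    errors = []
    for step in steps:
        try:
            step()
        except Exception as error:
            errors.append(error)

    if errors:
        raise Exception('\n' + '\n\n'.join(map(str, errors)))

try:
    run_all_or_report([lambda: 1 / 0, lambda: int('not a number')])
except Exception as combined:
    # Both the ZeroDivisionError and the ValueError appear in one message.
    print(combined)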

    def _fit_hyper_transformer(self, data, extra_columns):
@@ -610,25 +615,6 @@ def _transform_constraints(self, data, on_missing_column='error'):

        return data

    def _validate_data_on_constraints(self, data):
        """Make sure the given data is valid for the given constraints.

        Args:
            data (pandas.DataFrame):
                Table data.

        Returns:
            None

        Raises:
            ConstraintsNotMetError:
                If the table data is not valid for the provided constraints.
        """
        for constraint in self._constraints:
            if set(constraint.constraint_columns).issubset(data.columns.values):
                if not constraint.is_valid(data).all():
                    raise ConstraintsNotMetError('Data is not valid for the given constraints')

    def transform(self, data, on_missing_column='error'):
        """Transform the given data.

@@ -643,10 +629,6 @@ def transform(self, data, on_missing_column='error'):
        Returns:
            pandas.DataFrame:
                Transformed data.

        Raises:
            ConstraintsNotMetError:
                If the table data is not valid for the provided constraints.
        """
        if not self.fitted:
            raise MetadataNotFittedError()
@@ -655,8 +637,6 @@
        LOGGER.debug('Anonymizing table %s', self.name)
        data = self._anonymize(data[fields])

        self._validate_data_on_constraints(data)

        LOGGER.debug('Transforming constraints for table %s', self.name)
        data = self._transform_constraints(data, on_missing_column)

tests/integration/test_constraints.py (85 additions, 1 deletion)

@@ -1,4 +1,12 @@
from sdv.constraints import ColumnFormula, FixedCombinations, GreaterThan
import re

import pandas as pd
import pytest

from sdv.constraints import (
    Between, ColumnFormula, FixedCombinations, GreaterThan, Negative, OneHotEncoding, Positive,
    Rounding, Unique)
from sdv.constraints.errors import MultipleConstraintsErrors
from sdv.demo import load_tabular_demo
from sdv.tabular import GaussianCopula

@@ -40,3 +48,79 @@ def test_constraints(tmpdir):
    gc.save(tmpdir / 'test.pkl')
    gc = gc.load(tmpdir / 'test.pkl')
    gc.sample(10)


def test_failing_constraints():
    data = pd.DataFrame({
        'a': [0, 0, 0, 0, 0, 0, 0],
        'b': [1, -1, 2, -2, 3, -3, 5],
        'c': [-1, -1, -1, -1, -1, -1, -1],
        'd': [1, -1, 2, -2, 3, -3, 5],
        'e': [1, 2, 3, 4, 5, 6, 'a'],
        'f': [1, 1, 2, 2, 3, 3, -1],
        'g': [1, 0, 1, 0, 0, 1, 0],
        'h': [1, 1, 1, 0, 0, 10, 0],
        'i': [1, 1, 1, 1, 1, 1, 1]
    })

    constraints = [
        GreaterThan('a', 'b'),
        Positive('c'),
        Negative('d'),
        Rounding('e', 2),
        Between('f', 0, 3),
        OneHotEncoding(['g', 'h']),
        Unique('i')
    ]
    gc = GaussianCopula(constraints=constraints)

    err_msg = re.escape(
        "\nunsupported operand type(s) for -: 'str' and 'str'"

Member (Author): This PR only implements the "pretty print" for the general case where the data doesn't conform with constraint.is_valid(). If some other error shows up, it will just be printed as usual.

        '\n'
        "\nData is not valid for the 'OneHotEncoding' constraint:"
        '\n   g   h'
        '\n0  1   1'
        '\n2  1   1'
        '\n3  0   0'
        '\n4  0   0'
        '\n5  1  10'
        '\n+1 more'
        '\n'
        "\nData is not valid for the 'Unique' constraint:"
        '\n   i'
        '\n1  1'
        '\n2  1'
        '\n3  1'
        '\n4  1'
        '\n5  1'
        '\n+1 more'
        '\n'
        "\nData is not valid for the 'GreaterThan' constraint:"
        '\n   a  b'
        '\n1  0 -1'
        '\n3  0 -2'
        '\n5  0 -3'
        '\n'
        "\nData is not valid for the 'Positive' constraint:"
        '\n   c'
        '\n0 -1'
        '\n1 -1'
        '\n2 -1'
        '\n3 -1'
        '\n4 -1'
        '\n+2 more'
        '\n'
        "\nData is not valid for the 'Negative' constraint:"
        '\n   d'
        '\n0  1'
        '\n2  2'
        '\n4  3'
        '\n6  5'
        '\n'
        "\nData is not valid for the 'Between' constraint:"
        '\n   f'
        '\n6 -1'
    )

    with pytest.raises(MultipleConstraintsErrors, match=err_msg):
        gc.fit(data)