diff --git a/sdv/constraints/base.py b/sdv/constraints/base.py index 3eca9e68f..3265b57d2 100644 --- a/sdv/constraints/base.py +++ b/sdv/constraints/base.py @@ -12,6 +12,7 @@ from rdt import HyperTransformer from sdv.constraints.errors import MissingConstraintColumnError +from sdv.errors import ConstraintsNotMetError LOGGER = logging.getLogger(__name__) @@ -120,13 +121,17 @@ class Constraint(metaclass=ConstraintMeta): def _identity(self, table_data): return table_data + def _identity_with_validation(self, table_data): + self._validate_data_on_constraint(table_data) + return table_data + def __init__(self, handling_strategy, fit_columns_model=False): self.fit_columns_model = fit_columns_model if handling_strategy == 'transform': self.filter_valid = self._identity elif handling_strategy == 'reject_sampling': self.rebuild_columns = () - self.transform = self._identity + self.transform = self._identity_with_validation self.reverse_transform = self._identity elif handling_strategy != 'all': raise ValueError('Unknown handling strategy: {}'.format(handling_strategy)) @@ -220,6 +225,31 @@ def _sample_constraint_columns(self, table_data): sampled_data = pd.concat(all_sampled_rows, ignore_index=True) return sampled_data + def _validate_data_on_constraint(self, table_data): + """Make sure the given data is valid for the given constraints. + + Args: + data (pandas.DataFrame): + Table data. + + Raises: + ConstraintsNotMetError: + If the table data is not valid for the provided constraints. + """ + if set(self.constraint_columns).issubset(table_data.columns.values): + is_valid_data = self.is_valid(table_data) + if not is_valid_data.all(): + constraint_data = table_data[list(self.constraint_columns)] + invalid_rows = constraint_data[~is_valid_data] + err_msg = ( + f"Data is not valid for the '{self.__class__.__name__}' constraint:\n" + f'{invalid_rows[:5]}' + ) + if len(invalid_rows) > 5: + err_msg += f'\n+{len(invalid_rows) - 5} more' + + raise ConstraintsNotMetError(err_msg) + def _validate_constraint_columns(self, table_data): """Validate the columns in ``table_data``. @@ -277,6 +307,7 @@ def transform(self, table_data): pandas.DataFrame: Input data unmodified. """ + self._validate_data_on_constraint(table_data) table_data = self._validate_constraint_columns(table_data) return self._transform(table_data) diff --git a/sdv/constraints/errors.py b/sdv/constraints/errors.py index 4e4da55b4..ae5a592ed 100644 --- a/sdv/constraints/errors.py +++ b/sdv/constraints/errors.py @@ -3,3 +3,7 @@ class MissingConstraintColumnError(Exception): """Error to use when constraint is provided a table with missing columns.""" + + +class MultipleConstraintsErrors(Exception): + """Error used to represent a list of constraint errors.""" diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 41b33a9ec..6a1599959 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -19,7 +19,9 @@ on the other columns of the table. * Between: Ensure that the value in one column is always between the values of two other columns/scalars. + * Rounding: Round a column based on the specified number of digits. * OneHotEncoding: Ensure the rows of the specified columns are one hot encoded. + * Unique: Ensure that each value for a specified column/group of columns is unique. 
""" import operator @@ -1135,6 +1137,7 @@ class Unique(Constraint): def __init__(self, columns): self.columns = columns if isinstance(columns, list) else [columns] + self.constraint_columns = tuple(self.columns) super().__init__(handling_strategy='reject_sampling', fit_columns_model=False) def is_valid(self, table_data): diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py index 7851de16f..7d854aafd 100644 --- a/sdv/metadata/table.py +++ b/sdv/metadata/table.py @@ -10,8 +10,7 @@ from faker import Faker from sdv.constraints.base import Constraint -from sdv.constraints.errors import MissingConstraintColumnError -from sdv.errors import ConstraintsNotMetError +from sdv.constraints.errors import MissingConstraintColumnError, MultipleConstraintsErrors from sdv.metadata.errors import MetadataError, MetadataNotFittedError from sdv.metadata.utils import strings_from_regex @@ -443,9 +442,15 @@ def _get_transformers(self, dtypes): return transformers def _fit_transform_constraints(self, data): + errors = [] for constraint in self._constraints: - data = constraint.fit_transform(data) + try: + data = constraint.fit_transform(data) + except Exception as e: + errors.append(e) + if errors: + raise MultipleConstraintsErrors('\n' + '\n\n'.join(map(str, errors))) return data def _fit_hyper_transformer(self, data, extra_columns): @@ -610,25 +615,6 @@ def _transform_constraints(self, data, on_missing_column='error'): return data - def _validate_data_on_constraints(self, data): - """Make sure the given data is valid for the given constraints. - - Args: - data (pandas.DataFrame): - Table data. - - Returns: - None - - Raises: - ConstraintsNotMetError: - If the table data is not valid for the provided constraints. - """ - for constraint in self._constraints: - if set(constraint.constraint_columns).issubset(data.columns.values): - if not constraint.is_valid(data).all(): - raise ConstraintsNotMetError('Data is not valid for the given constraints') - def transform(self, data, on_missing_column='error'): """Transform the given data. @@ -643,10 +629,6 @@ def transform(self, data, on_missing_column='error'): Returns: pandas.DataFrame: Transformed data. - - Raises: - ConstraintsNotMetError: - If the table data is not valid for the provided constraints. 
""" if not self.fitted: raise MetadataNotFittedError() @@ -655,8 +637,6 @@ def transform(self, data, on_missing_column='error'): LOGGER.debug('Anonymizing table %s', self.name) data = self._anonymize(data[fields]) - self._validate_data_on_constraints(data) - LOGGER.debug('Transforming constraints for table %s', self.name) data = self._transform_constraints(data, on_missing_column) diff --git a/tests/integration/test_constraints.py b/tests/integration/test_constraints.py index 091a52a60..d1609e875 100644 --- a/tests/integration/test_constraints.py +++ b/tests/integration/test_constraints.py @@ -1,4 +1,12 @@ -from sdv.constraints import ColumnFormula, FixedCombinations, GreaterThan +import re + +import pandas as pd +import pytest + +from sdv.constraints import ( + Between, ColumnFormula, FixedCombinations, GreaterThan, Negative, OneHotEncoding, Positive, + Rounding, Unique) +from sdv.constraints.errors import MultipleConstraintsErrors from sdv.demo import load_tabular_demo from sdv.tabular import GaussianCopula @@ -40,3 +48,79 @@ def test_constraints(tmpdir): gc.save(tmpdir / 'test.pkl') gc = gc.load(tmpdir / 'test.pkl') gc.sample(10) + + +def test_failing_constraints(): + data = pd.DataFrame({ + 'a': [0, 0, 0, 0, 0, 0, 0], + 'b': [1, -1, 2, -2, 3, -3, 5], + 'c': [-1, -1, -1, -1, -1, -1, -1], + 'd': [1, -1, 2, -2, 3, -3, 5], + 'e': [1, 2, 3, 4, 5, 6, 'a'], + 'f': [1, 1, 2, 2, 3, 3, -1], + 'g': [1, 0, 1, 0, 0, 1, 0], + 'h': [1, 1, 1, 0, 0, 10, 0], + 'i': [1, 1, 1, 1, 1, 1, 1] + }) + + constraints = [ + GreaterThan('a', 'b'), + Positive('c'), + Negative('d'), + Rounding('e', 2), + Between('f', 0, 3), + OneHotEncoding(['g', 'h']), + Unique('i') + ] + gc = GaussianCopula(constraints=constraints) + + err_msg = re.escape( + "\nunsupported operand type(s) for -: 'str' and 'str'" + '\n' + "\nData is not valid for the 'OneHotEncoding' constraint:" + '\n g h' + '\n0 1 1' + '\n2 1 1' + '\n3 0 0' + '\n4 0 0' + '\n5 1 10' + '\n+1 more' + '\n' + "\nData is not valid for the 'Unique' constraint:" + '\n i' + '\n1 1' + '\n2 1' + '\n3 1' + '\n4 1' + '\n5 1' + '\n+1 more' + '\n' + "\nData is not valid for the 'GreaterThan' constraint:" + '\n a b' + '\n1 0 -1' + '\n3 0 -2' + '\n5 0 -3' + '\n' + "\nData is not valid for the 'Positive' constraint:" + '\n c' + '\n0 -1' + '\n1 -1' + '\n2 -1' + '\n3 -1' + '\n4 -1' + '\n+2 more' + '\n' + "\nData is not valid for the 'Negative' constraint:" + '\n d' + '\n0 1' + '\n2 2' + '\n4 3' + '\n6 5' + '\n' + "\nData is not valid for the 'Between' constraint:" + '\n f' + '\n6 -1' + ) + + with pytest.raises(MultipleConstraintsErrors, match=err_msg): + gc.fit(data) diff --git a/tests/unit/constraints/test_base.py b/tests/unit/constraints/test_base.py index b6ba1a14b..14d188e72 100644 --- a/tests/unit/constraints/test_base.py +++ b/tests/unit/constraints/test_base.py @@ -11,6 +11,7 @@ from sdv.constraints.base import Constraint, _get_qualified_name, get_subclasses, import_object from sdv.constraints.errors import MissingConstraintColumnError from sdv.constraints.tabular import ColumnFormula, FixedCombinations +from sdv.errors import ConstraintsNotMetError def test__get_qualified_name_class(): @@ -182,7 +183,7 @@ def test___init___reject_sampling(self): # Asserts assert instance.filter_valid != instance._identity - assert instance.transform == instance._identity + assert instance.transform == instance._identity_with_validation assert instance.reverse_transform == instance._identity def test___init___all(self): @@ -298,25 +299,106 @@ def test_fit_trains_column_model(self, ht_mock, 
gm_mock): assert len(calls) == 1 pd.testing.assert_frame_equal(args[0], table_data) + def test__validate_data_on_constraints(self): + """Test the ``_validate_data_on_constraint`` method. + + Expect that the method calls ``is_valid`` when the constraint columns + are in the given data. + + Input: + - Table data + Output: + - None + Side Effects: + - No error + """ + # Setup + data = pd.DataFrame({ + 'a': [0, 1, 2], + 'b': [3, 4, 5] + }, index=[0, 1, 2]) + constraint = Constraint(handling_strategy='transform') + constraint.constraint_columns = ['a', 'b'] + constraint.is_valid = Mock() + + # Run + constraint._validate_data_on_constraint(data) + + # Assert + constraint.is_valid.assert_called_once_with(data) + + def test__validate_data_on_constraints_invalid_input(self): + """Test the ``_validate_data_on_constraint`` method. + + Expect that the method raises an error when the constraint columns + are in the given data and the ``is_valid`` returns False for any row. + + Input: + - Table data contains an invalid row + Output: + - None + Side Effects: + - A ``ConstraintsNotMetError`` is thrown + """ + # Setup + data = pd.DataFrame({ + 'a': [0, 1, 2], + 'b': [3, 4, 5] + }, index=[0, 1, 2]) + constraint = Constraint(handling_strategy='transform') + constraint.constraint_columns = ['a', 'b'] + constraint.is_valid = Mock(return_value=pd.Series([True, False, True])) + + # Run / Assert + with pytest.raises(ConstraintsNotMetError): + constraint._validate_data_on_constraint(data) + + def test__validate_data_on_constraints_missing_cols(self): + """Test the ``_validate_data_on_constraint`` method. + + Expect that the method doesn't do anything when the columns are not in the given data. + + Input: + - Table data that is missing a constraint column + Output: + - None + Side Effects: + - No error + """ + # Setup + data = pd.DataFrame({ + 'a': [0, 1, 2], + 'b': [3, 4, 5] + }, index=[0, 1, 2]) + constraint = Constraint(handling_strategy='transform') + constraint.constraint_columns = ['a', 'b', 'c'] + constraint.is_valid = Mock() + + # Run + constraint._validate_data_on_constraint(data) + + # Assert + assert not constraint.is_valid.called + def test_transform(self): """Test the ``Constraint.transform`` method. - It is an identity method for completion, to be optionally - overwritten by subclasses. + When no constraints are passed, it behaves like an identity method, + to be optionally overwritten by subclasses. The ``Constraint.transform`` method is expected to: - Return the input data unmodified. Input: - - Anything + - a DataFrame Output: - Input """ # Run instance = Constraint(handling_strategy='transform') - output = instance.transform('input') + output = instance.transform(pd.DataFrame({'col': ['input']})) # Assert - assert output == 'input' + pd.testing.assert_frame_equal(output, pd.DataFrame({'col': ['input']})) def test_transform_calls__transform(self): """Test that the ``Constraint.transform`` method calls ``_transform``. 
@@ -415,7 +497,7 @@ def test_transform_model_enabled_some_columns_missing(self): transformed_data = instance.transform(data) # Assert - expected_tranformed_data = pd.DataFrame([[1, 2, 3]], columns=['b', 'c', 'a']) + expected_transformed_data = pd.DataFrame([[1, 2, 3]], columns=['b', 'c', 'a']) expected_result = pd.DataFrame([ [5, 1, 2], [6, 3, 4] @@ -425,8 +507,8 @@ def test_transform_model_enabled_some_columns_missing(self): instance._columns_model.sample.assert_any_call(num_rows=1, conditions={'b': 1}) instance._columns_model.sample.assert_any_call(num_rows=1, conditions={'b': 3}) reverse_transform_calls = instance._hyper_transformer.reverse_transform.mock_calls - pd.testing.assert_frame_equal(reverse_transform_calls[0][1][0], expected_tranformed_data) - pd.testing.assert_frame_equal(reverse_transform_calls[1][1][0], expected_tranformed_data) + pd.testing.assert_frame_equal(reverse_transform_calls[0][1][0], expected_transformed_data) + pd.testing.assert_frame_equal(reverse_transform_calls[1][1][0], expected_transformed_data) pd.testing.assert_frame_equal(transformed_data, expected_result) def test_transform_model_enabled_reject_sampling(self): diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py index 70ae8c64f..1ba780da4 100644 --- a/tests/unit/constraints/test_tabular.py +++ b/tests/unit/constraints/test_tabular.py @@ -5366,7 +5366,7 @@ def test___init__(self): # Assert assert instance.columns == ['a', 'b'] assert instance.fit_columns_model is False - assert instance.transform == instance._identity + assert instance.transform == instance._identity_with_validation assert instance.reverse_transform == instance._identity def test___init__one_column(self): diff --git a/tests/unit/metadata/test_dataset.py b/tests/unit/metadata/test_dataset.py index a1f8a88a9..7aeeef13a 100644 --- a/tests/unit/metadata/test_dataset.py +++ b/tests/unit/metadata/test_dataset.py @@ -892,7 +892,7 @@ def test_add_table_with_constraints(self): - Constraints for the given table Side Effects: - An entry is added to the metadata for the provided table, which contains - the given fields and constrants. + the given fields and constraints. """ # Setup metadata = Mock(spec_set=Metadata) diff --git a/tests/unit/metadata/test_table.py b/tests/unit/metadata/test_table.py index 8efe95eaf..e463eb6db 100644 --- a/tests/unit/metadata/test_table.py +++ b/tests/unit/metadata/test_table.py @@ -8,7 +8,6 @@ from sdv.constraints.base import Constraint from sdv.constraints.errors import MissingConstraintColumnError -from sdv.errors import ConstraintsNotMetError from sdv.metadata import Table @@ -582,93 +581,6 @@ def test__transform_constraints_drops_columns(self): }, index=[0, 1, 2]) assert result.equals(expected_result) - def test__validate_data_on_constraints(self): - """Test the ``Table._validate_data_on_constraints`` method. - - Expect that the method returns True when the constraint columns are in the given data, - and the constraint.is_valid method returns True. 
- - Input: - - Table data - Output: - - None - Side Effects: - - No error - """ - # Setup - data = pd.DataFrame({ - 'a': [0, 1, 2], - 'b': [3, 4, 5] - }, index=[0, 1, 2]) - constraint_mock = Mock() - constraint_mock.is_valid.return_value = pd.Series([True, True, True]) - constraint_mock.constraint_columns = ['a', 'b'] - table_mock = Mock() - table_mock._constraints = [constraint_mock] - - # Run - result = Table._validate_data_on_constraints(table_mock, data) - - # Assert - assert result is None - - def test__validate_data_on_constraints_invalid_input(self): - """Test the ``Table._validate_data_on_constraints`` method. - - Expect that the method returns False when the constraint columns are in the given data, - and the constraint.is_valid method returns False for any row. - - Input: - - Table data contains an invalid row - Output: - - None - Side Effects: - - A ConstraintsNotMetError is thrown - """ - # Setup - data = pd.DataFrame({ - 'a': [0, 1, 2], - 'b': [3, 4, 5] - }, index=[0, 1, 2]) - constraint_mock = Mock() - constraint_mock.is_valid.return_value = pd.Series([True, False, True]) - constraint_mock.constraint_columns = ['a', 'b'] - table_mock = Mock() - table_mock._constraints = [constraint_mock] - - # Run and assert - with pytest.raises(ConstraintsNotMetError): - Table._validate_data_on_constraints(table_mock, data) - - def test__validate_data_on_constraints_missing_cols(self): - """Test the ``Table._validate_data_on_constraints`` method. - - Expect that the method returns True when the constraint columns are not - in the given data. - - Input: - - Table data that is missing a constraint column - Output: - - None - Side Effects: - - No error - """ - # Setup - data = pd.DataFrame({ - 'a': [0, 1, 2], - 'b': [3, 4, 5] - }, index=[0, 1, 2]) - constraint_mock = Mock() - constraint_mock.constraint_columns = ['a', 'b', 'c'] - table_mock = Mock() - table_mock._constraints = [constraint_mock] - - # Run - result = Table._validate_data_on_constraints(table_mock, data) - - # Assert - assert result is None - def test_from_dict_min_max(self): """Test the ``Table.from_dict`` method.
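For reference, a minimal usage sketch of the behavior this patch introduces, mirroring the integration test above: validation now runs inside each constraint's ``transform``, and any failures raised while fitting the constraints are collected and re-raised together as a single ``MultipleConstraintsErrors``. The imports and calls below are the ones exercised by the tests in this diff; the column names ('amount', 'id') and the data are hypothetical.

import pandas as pd

from sdv.constraints import Positive, Unique
from sdv.constraints.errors import MultipleConstraintsErrors
from sdv.tabular import GaussianCopula

# Hypothetical data that violates both constraints below.
data = pd.DataFrame({
    'amount': [-1, -2, -3],  # violates Positive('amount')
    'id': [1, 1, 2],         # violates Unique('id')
})

model = GaussianCopula(constraints=[Positive('amount'), Unique('id')])

try:
    model.fit(data)
except MultipleConstraintsErrors as error:
    # The combined message lists up to five invalid rows per failing
    # constraint instead of stopping at the first constraint that fails.
    print(error)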