Skip to content

Commit

Permalink
Add test case + print correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed May 24, 2022
1 parent 5a35b3f commit 30276b1
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 13 deletions.
13 changes: 7 additions & 6 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,14 @@ def _validate_data_on_constraint(self, table_data):
if set(self.constraint_columns).issubset(table_data.columns.values):
is_valid = self.is_valid(table_data)
if not is_valid.all():
invalid_rows = table_data[~is_valid]
err_msg = [
f"Data is not valid for the '{self.__class__.__name__}' constraint:\n",
f'{invalid_rows[:5]}\n'
]
constraint_data = table_data[list(self.constraint_columns)]
invalid_rows = constraint_data[~is_valid]
err_msg = (
f"Data is not valid for the '{self.__class__.__name__}' constraint:\n"
f'{invalid_rows[:5]}'
)
if len(invalid_rows) > 5:
err_msg.append(f'+{len(invalid_rows) - 5} more')
err_msg += f'\n+{len(invalid_rows) - 5} more'

raise ConstraintsNotMetError(err_msg)

Expand Down
7 changes: 2 additions & 5 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
* Negative: Ensure that the values in given columns are always negative.
* ColumnFormula: Compute the value of a column based on applying a formula
on the other columns of the table.
* Rounding: Round a column based on the specified number of digits.
* Between: Ensure that the value in one column is always between the values
of two other columns/scalars.
* Rounding: Round a column based on the specified number of digits.
* OneHotEncoding: Ensure the rows of the specified columns are one hot encoded.
* Unique: Ensure that each value for a specified column/group of columns is unique.
"""
Expand Down Expand Up @@ -1153,7 +1153,4 @@ def is_valid(self, table_data):
pandas.Series:
Whether each row is valid.
"""
print('table_data')
data = table_data.groupby(self.columns, dropna=False).cumcount() == 0
print(data)
return data
return table_data.groupby(self.columns, dropna=False).cumcount() == 0
2 changes: 1 addition & 1 deletion sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def _fit_transform_constraints(self, data):
errors.append(e)

if errors:
raise MultipleConstraintsErrors(errors)
raise MultipleConstraintsErrors('\n' + '\n\n'.join(map(str, errors)))
return data

def _fit_hyper_transformer(self, data, extra_columns):
Expand Down
75 changes: 74 additions & 1 deletion tests/integration/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from sdv.constraints import ColumnFormula, FixedCombinations, GreaterThan
import re

import pandas as pd
import pytest

from sdv.constraints import (
Between, ColumnFormula, FixedCombinations, GreaterThan, Negative, OneHotEncoding, Positive,
Rounding)
from sdv.constraints.errors import MultipleConstraintsErrors
from sdv.demo import load_tabular_demo
from sdv.tabular import GaussianCopula

Expand Down Expand Up @@ -40,3 +48,68 @@ def test_constraints(tmpdir):
gc.save(tmpdir / 'test.pkl')
gc = gc.load(tmpdir / 'test.pkl')
gc.sample(10)


def test_failing_constraints():
data = pd.DataFrame({
'a': [0, 0, 0, 0, 0, 0, 0],
'b': [1, -1, 2, -2, 3, -3, 5],
'c': [-1, -1, -1, -1, -1, -1, -1],
'd': [1, -1, 2, -2, 3, -3, 5],
'e': [1, 2, 3, 4, 5, 6, 'a'],
'f': [1, 1, 2, 2, 3, 3, -1],
'g': [1, 0, 1, 0, 0, 1, 0],
'h': [1, 1, 1, 0, 0, 10, 0],
})

constraints = [
GreaterThan('a', 'b'),
Positive('c'),
Negative('d'),
Rounding('e', 2),
Between('f', 0, 3),
OneHotEncoding(['g', 'h']),
]
gc = GaussianCopula(constraints=constraints)

err_msg = re.escape(
"\nunsupported operand type(s) for -: 'str' and 'str'"
'\n'
"\nData is not valid for the 'OneHotEncoding' constraint:"
'\n g h'
'\n0 1 1'
'\n2 1 1'
'\n3 0 0'
'\n4 0 0'
'\n5 1 10'
'\n+1 more'
'\n'
"\nData is not valid for the 'GreaterThan' constraint:"
'\n a b'
'\n1 0 -1'
'\n3 0 -2'
'\n5 0 -3'
'\n'
"\nData is not valid for the 'Positive' constraint:"
'\n c'
'\n0 -1'
'\n1 -1'
'\n2 -1'
'\n3 -1'
'\n4 -1'
'\n+2 more'
'\n'
"\nData is not valid for the 'Negative' constraint:"
'\n d'
'\n0 1'
'\n2 2'
'\n4 3'
'\n6 5'
'\n'
"\nData is not valid for the 'Between' constraint:"
'\n f'
'\n6 -1'
)

with pytest.raises(MultipleConstraintsErrors, match=err_msg):
gc.fit(data)

0 comments on commit 30276b1

Please sign in to comment.