diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 978683464..3a14d30a3 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -370,20 +370,10 @@ def get_learned_distributions(self, table_name): def _validate_constraints(self, constraints): for constraint_dict in constraints: - params = {'constraint_class', 'constraint_parameters', 'table_name'} - keys = constraint_dict.keys() - missing_params = params - keys - if missing_params: + if 'table_name' not in constraint_dict.keys(): raise SynthesizerInputError( - f'A constraint is missing required parameters {missing_params}. ' - 'Please add these parameters to your constraint definition.' - ) - - extra_params = keys - params - if extra_params: - raise SynthesizerInputError( - f'Unrecognized constraint parameter {extra_params}. ' - 'Please remove these parameters from your constraint definition.' + "A constraint is missing required parameter 'table_name'. " + 'Please add this parameter to your constraint definition.' ) if constraint_dict['constraint_class'] == 'Unique': diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index fe371376d..84e73f3a3 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -490,7 +490,7 @@ def test__validate_constraint_dict_key_error(self, mock_constraint): """Validate that an ``InvalidConstraintsError`` is raised when the class is not found.""" # Setup constraint_dict = { - 'constraint_class': 'Positiv', + 'constraint_class': 'Positive', 'constraint_parameters': {'column_name': 'col1'} } mock_constraint._get_class_from_dict.side_effect = [KeyError] @@ -499,7 +499,7 @@ def test__validate_constraint_dict_key_error(self, mock_constraint): dp = DataProcessor(metadata) # Run and Assert - error_msg = re.escape("Invalid constraint class ('Positiv').") + error_msg = re.escape("Invalid constraint class ('Positive').") with pytest.raises(InvalidConstraintsError, match=error_msg): dp._validate_constraint_dict(constraint_dict) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index f3adf78cc..8aed387e5 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -772,44 +772,19 @@ def test_get_constraints(self): # Assert assert output == constraints - def test_add_constraints_missing_parameters(self): - """Test error raised when required params are missing.""" + def test_add_constraints_missing_table_name(self): + """Test error raised when ``table_name`` is missing.""" # Setup data = pd.DataFrame({'col': [1, 2, 3]}) metadata = MultiTableMetadata() metadata.detect_table_from_dataframe('table', data) - constraint = {'constraint_class': 'Inequality', 'table_name': 'test'} + constraint = {'constraint_class': 'Inequality'} model = BaseMultiTableSynthesizer(metadata) # Run and Assert err_msg = re.escape( - "A constraint is missing required parameters {'constraint_parameters'}. " - 'Please add these parameters to your constraint definition.' - ) - with pytest.raises(SynthesizerInputError, match=err_msg): - model.add_constraints([constraint]) - - def test_add_constraints_invalid_parameters(self): - """Test error raised when invalid params are passed.""" - # Setup - data = pd.DataFrame({'col': [1, 2, 3]}) - metadata = MultiTableMetadata() - metadata.detect_table_from_dataframe('table', data) - constraint = { - 'constraint_class': 'Inequality', - 'table_name': 'test', - 'constraint_parameters': { - 'low_column_name': 'col', - 'high_column_name': 'col' - }, - 'invalid': 42 - } - model = BaseMultiTableSynthesizer(metadata) - - # Run and Assert - err_msg = re.escape( - "Unrecognized constraint parameter {'invalid'}. " - 'Please remove these parameters from your constraint definition.' + "A constraint is missing required parameter 'table_name'. " + 'Please add this parameter to your constraint definition.' ) with pytest.raises(SynthesizerInputError, match=err_msg): model.add_constraints([constraint])