diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 6ffec8cee7d9..29d0a6831028 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1252,8 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True): return trainset -@pytest.mark.parametrize("test_with_interaction_constraints", [True, False]) -def test_monotone_constraints(test_with_interaction_constraints): +@pytest.mark.parametrize("test_with_categorical_variable", [True, False]) +def test_monotone_constraints(test_with_categorical_variable): def is_increasing(y): return (np.diff(y) >= 0.0).all() @@ -1316,10 +1316,12 @@ def has_interaction(treef): return not has_interaction_flag.any() - for test_with_categorical_variable in [True, False]: - trainset = generate_trainset_for_monotone_constraints_tests( - test_with_categorical_variable - ) + trainset = generate_trainset_for_monotone_constraints_tests( + test_with_categorical_variable + ) + for test_with_interaction_constraints in [True, False]: + error_msg = ("Model not correctly constrained " + f"(test_with_interaction_constraints={test_with_interaction_constraints})") for monotone_constraints_method in ["basic", "intermediate", "advanced"]: params = { "min_data": 20, @@ -1333,7 +1335,7 @@ def has_interaction(treef): constrained_model = lgb.train(params, trainset) assert is_correctly_constrained( constrained_model, test_with_categorical_variable - ) + ), error_msg if test_with_interaction_constraints: feature_sets = [["Column_0"], ["Column_1"], "Column_2"] assert are_interactions_enforced(constrained_model, feature_sets) @@ -1399,8 +1401,9 @@ def test_monotone_penalty_max(): } unconstrained_model = lgb.train(params_unconstrained_model, trainset_unconstrained_model, 10) - unconstrained_model_predictions = unconstrained_model.\ - predict(x3_negatively_correlated_with_y.reshape(-1, 1)) + unconstrained_model_predictions = unconstrained_model.predict( + x3_negatively_correlated_with_y.reshape(-1, 1) + ) for monotone_constraints_method in ["basic", "intermediate", "advanced"]: params_constrained_model["monotone_constraints_method"] = monotone_constraints_method