From 15da2c64f135b8c5732b1366070e398397043b36 Mon Sep 17 00:00:00 2001 From: BenjaminBossan Date: Sun, 8 Jan 2023 11:57:20 +0100 Subject: [PATCH] Prevent CatBoost from creating folder when testing By default, CatBoost creates a "catboost_info/" folder in root when a model is fitted during testing. This is prevented now. --- skops/io/tests/test_external.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index f9edcffc..a8e6038f 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -283,7 +283,7 @@ def trusted(self): @pytest.mark.parametrize("boosting_type", boosting_types) def test_classifier(self, catboost, cb_clf_data, trusted, boosting_type): estimator = catboost.CatBoostClassifier( - verbose=False, boosting_type=boosting_type + verbose=False, boosting_type=boosting_type, allow_writing_files=False ) loaded = loads(dumps(estimator), trusted=trusted) assert_params_equal(estimator.get_params(), loaded.get_params()) @@ -296,7 +296,7 @@ def test_classifier(self, catboost, cb_clf_data, trusted, boosting_type): @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): estimator = catboost.CatBoostRegressor( - verbose=False, boosting_type=boosting_type + verbose=False, boosting_type=boosting_type, allow_writing_files=False ) loaded = loads(dumps(estimator), trusted=trusted) assert_params_equal(estimator.get_params(), loaded.get_params()) @@ -308,7 +308,9 @@ def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): - estimator = catboost.CatBoostRanker(verbose=False, boosting_type=boosting_type) + estimator = catboost.CatBoostRanker( + verbose=False, boosting_type=boosting_type, allow_writing_files=False + ) loaded = loads(dumps(estimator), trusted=trusted) assert_params_equal(estimator.get_params(), loaded.get_params())