Skip to content

Commit

Permalink
Merge pull request #280 from KhiopsML/273-eliminate-is_fitted_-sklear…
Browse files Browse the repository at this point in the history
…n-attribute
  • Loading branch information
folmos-at-orange authored Jan 8, 2025
2 parents 4e9aba0 + 0b1671d commit d093ae8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 26 deletions.
50 changes: 25 additions & 25 deletions khiops/sklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
RegressorMixin,
TransformerMixin,
)
from sklearn.exceptions import NotFittedError
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import assert_all_finite, check_is_fitted, column_or_1d

Expand Down Expand Up @@ -302,7 +303,7 @@ def _undefine_estimator_attributes(self):

def _get_main_dictionary(self):
    """Returns the model's main Khiops dictionary

    Raises
    ------
    `AssertionError`
        If the estimator is not fitted.
    """
    # The fitted-state check is the only guard: reading ``model_`` before it
    # could fail with an unrelated error on a non-fitted instance
    self._assert_is_fitted()
    return self.model_.get_dictionary(self.model_main_dictionary_name_)

def export_report_file(self, report_file_path):
Expand All @@ -318,16 +319,14 @@ def export_report_file(self, report_file_path):
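`ValueError`
When the instance is not fitted (raised as `sklearn.exceptions.NotFittedError`, a `ValueError` subclass) or when the report is not available.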
"""
if not self.is_fitted_:
raise ValueError(f"{self.__class__.__name__} not fitted yet.")
check_is_fitted(self)
if self.model_report_ is None:
raise ValueError("Report not available (imported model?).")
self.model_report_.write_khiops_json_file(report_file_path)

def export_dictionary_file(self, dictionary_file_path):
    """Export the model's Khiops dictionary file (.kdic)

    Parameters
    ----------
    dictionary_file_path : str
        Path of the output dictionary file.

    Raises
    ------
    `sklearn.exceptions.NotFittedError`
        When the instance is not fitted.
    """
    # Rely on scikit-learn's own fitted-state check instead of reading an
    # ``is_fitted_`` flag, which does not exist on a non-fitted instance
    check_is_fitted(self)
    self.model_.export_khiops_dictionary_file(dictionary_file_path)

def _import_model(self, kdic_path):
Expand Down Expand Up @@ -384,13 +383,19 @@ def fit(self, X, y=None, **kwargs):
# If on "fitted" state then:
# - self.model_ must be a DictionaryDomain
# - self.model_report_ must be a KhiopsJSONObject
if hasattr(self, "is_fitted_") and self.is_fitted_:
assert hasattr(self, "model_") and isinstance(
self.model_, kh.DictionaryDomain
)
assert hasattr(self, "model_report_") and isinstance(
self.model_report_, kh.KhiopsJSONObject
)
try:
check_is_fitted(self)
assert isinstance(self.model_, kh.DictionaryDomain)
assert isinstance(self.model_report_, kh.KhiopsJSONObject)
assert isinstance(self.model_, kh.DictionaryDomain)
# Note:
# We ignore any NotFittedError raised by check_is_fitted because we are using
# the try/except as an if/else. The intended code is
# if check_is_fitted(self):
# # asserts
# But check_is_fitted follows a do-nothing-or-raise pattern.
except NotFittedError:
pass

return self

Expand Down Expand Up @@ -424,7 +429,6 @@ def _fit(self, ds, computation_dir, **kwargs):
and isinstance(self.model_report_, kh.KhiopsJSONObject)
):
self._fit_training_post_process(ds)
self.is_fitted_ = True
self.is_multitable_model_ = ds.is_multitable
self.n_features_in_ = ds.main_table.n_features()

Expand Down Expand Up @@ -649,6 +653,12 @@ def _create_computation_dir(self, method_name):
prefix=f"{self.__class__.__name__}_{method_name}_"
)

def _assert_is_fitted(self):
    """Asserts that the estimator is fitted

    Variant of `check_is_fitted` for internal invariants: the `NotFittedError`
    it raises is converted to an `AssertionError`.

    Raises
    ------
    `AssertionError`
        If `check_is_fitted` reports that the instance is not fitted.
    """
    try:
        check_is_fitted(self)
    except NotFittedError as exc:
        # Chain explicitly so the original scikit-learn error is kept as the cause
        raise AssertionError("Model not fitted") from exc


# Note: scikit-learn **requires** inherit first the mixins and then other classes
class KhiopsCoclustering(ClusterMixin, KhiopsEstimator):
Expand Down Expand Up @@ -704,8 +714,6 @@ class KhiopsCoclustering(ClusterMixin, KhiopsEstimator):
Attributes
----------
is_fitted_ : bool
``True`` if the estimator is fitted.
is_multitable_model_ : bool
``True`` if the model was fitted on a multi-table dataset.
model_ : `.DictionaryDomain`
Expand Down Expand Up @@ -1152,7 +1160,6 @@ def _simplify(
# Copy relevant attributes
# Note: do not copy `model_*` attributes, that get rebuilt anyway
for attribute_name in (
"is_fitted_",
"is_multitable_model_",
"model_main_dictionary_name_",
"model_id_column",
Expand Down Expand Up @@ -1215,8 +1222,7 @@ def simplify(
A *new*, simplified `.KhiopsCoclustering` estimator instance.
"""
# Check that the estimator is fitted:
if not self.is_fitted_:
raise ValueError("Only fitted coclustering estimators can be simplified")
check_is_fitted(self)

return self._simplify(
max_preserved_information=max_preserved_information,
Expand Down Expand Up @@ -2015,8 +2021,6 @@ class KhiopsClassifier(ClassifierMixin, KhiopsPredictor):
- Importance: The geometric mean between the Level and the Weight.
is_fitted_ : bool
``True`` if the estimator is fitted.
is_multitable_model_ : bool
``True`` if the model was fitted on a multi-table dataset.
model_ : `.DictionaryDomain`
Expand Down Expand Up @@ -2097,7 +2101,7 @@ def _is_real_target_dtype_integer(self):

def _sorted_prob_variable_names(self):
"""Returns the model probability variable names in the order of self.classes_"""
assert self.is_fitted_, "Model not fit yet"
self._assert_is_fitted()

# Collect the probability variables from the model main dictionary
prob_variables = []
Expand Down Expand Up @@ -2483,8 +2487,6 @@ class KhiopsRegressor(RegressorMixin, KhiopsPredictor):
- Importance: The geometric mean between the Level and the Weight.
is_fitted_ : bool
``True`` if the estimator is fitted.
is_multitable_model_ : bool
``True`` if the model was fitted on a multi-table dataset.
model_ : `.DictionaryDomain`
Expand Down Expand Up @@ -2774,8 +2776,6 @@ class KhiopsEncoder(TransformerMixin, KhiopsSupervisedEstimator):
Level of the features evaluated by the classifier. The Level is measure of the
predictive importance of the feature taken individually. It ranges between 0 (no
predictive interest) and 1 (optimal predictive importance).
is_fitted_ : bool
``True`` if the estimator is fitted.
is_multitable_model_ : bool
``True`` if the model was fitted on a multi-table dataset.
model_ : `.DictionaryDomain`
Expand Down
1 change: 0 additions & 1 deletion tests/test_estimator_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def assert_attribute_values_ok(self, model, X, y):
self.assertEqual(
model.n_features_used_, len(feature_used_importances_report)
)
self.assertTrue(model.is_fitted_)

def test_classifier_attributes_monotable(self):
"""Test consistency of KhiopsClassifier's attributes with the output reports
Expand Down
24 changes: 24 additions & 0 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import shutil
import unittest
import warnings
from itertools import product

import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import NotFittedError, check_is_fitted

Expand Down Expand Up @@ -2704,3 +2706,25 @@ def test_khiops_encoder_no_output_variables_implies_not_fit(self):

# Check that the encoder is not fit
self.assertNotFit(khe)

def test_export_operations_raise_when_not_fitted(self):
    """Test that the export methods raise NotFittedError on non-fitted estimators

    .. note::
        The standard operations (predict, predict_proba, transform, etc) are
        covered by KhiopsSklearnEstimatorStandardTests.
    """
    # Fixtures: the export method names and one non-fitted instance per estimator
    export_method_names = ("export_dictionary_file", "export_report_file")
    non_fitted_estimators = (
        KhiopsClassifier(),
        KhiopsRegressor(),
        KhiopsEncoder(),
        KhiopsCoclustering(),
    )

    # Check every (export method, estimator) combination
    for export_operation in export_method_names:
        for estimator in non_fitted_estimators:
            with self.subTest(
                export_operation=export_operation, estimator=estimator
            ):
                with self.assertRaises(NotFittedError):
                    getattr(estimator, export_operation)("report.khj")

0 comments on commit d093ae8

Please sign in to comment.