Skip to content

Commit

Permalink
Merge branch 'cfe_classes_' into fix_r_evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jun 14, 2024
2 parents 6a4cd07 + 32c721d commit e3df56a
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class CrossFitEstimator:
_overall_estimator: _ScikitModel | None = field(init=False)
_test_indices: tuple[np.ndarray] | None = field(init=False)
_n_classes: int | None = field(init=False)
classes_: np.ndarray | None = field(init=False)

def __post_init__(self):
_validate_n_folds(self.n_folds)
Expand All @@ -115,6 +116,7 @@ def __post_init__(self):
self._overall_estimator: _ScikitModel | None = None
self._test_indices: tuple[np.ndarray] | None = None
self._n_classes: int | None = None
self.classes_: np.ndarray | None = None

def _train_overall_estimator(
self, X: Matrix, y: Matrix | Vector, fit_params: dict | None = None
Expand Down Expand Up @@ -189,7 +191,14 @@ def fit(

if is_classifier(self):
self._n_classes = len(np.unique(y))

self.classes_ = np.unique(y)
for e in self._estimators:
if set(e.classes_) != set(self.classes_): # type: ignore
raise ValueError(
"Some cross fit estimators training data had less classes than "
"the overall estimator. Please check the cv parameter. If you are "
"synchronizing the folds in a MetaLearner consider not doing it."
)
return self

def _initialize_prediction_tensor(
Expand Down

0 comments on commit e3df56a

Please sign in to comment.