Skip to content

Commit

Permalink
ENHANCEMENT: Add a prediction function that also calculates uncertainty in the prediction (#584)
Browse files Browse the repository at this point in the history

* add prediction function that also calculates uncertainty in the prediction

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>

* Add error message when max_bins is not equal to max_interaction_bins when calling uncertainty prediction function

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>

* use ebm_predict_scores to simplify processing in pred_from_base_models_with_uncertainty

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>

* fix to actually iterate through bags not scores in uncertainty prediction function

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>

* add test for uncertainty prediction function

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>

* adapt test to be more generalizable

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>

---------

Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
Co-authored-by: Fabian Degen <fabian.degen@mytum.de>
  • Loading branch information
degenfabian and Fabian Degen authored Nov 4, 2024
1 parent e1182f1 commit 3fdcab5
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
38 changes: 38 additions & 0 deletions python/interpret-core/interpret/glassbox/_ebm/_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,44 @@ def scale(self, term, factor):

return self

def pred_from_base_models_with_uncertainty(self, instances):
    """Gets raw scores and uncertainties from the bagged base models.

    Each outer-bag model scores the samples independently; the mean of
    those per-bag scores is the prediction and their standard deviation
    is the uncertainty estimate.

    Args:
        instances: ndarray of shape (n_samples, n_features)
            The input samples to predict on.

    Returns:
        ndarray of shape (n_samples, 2)
            First column contains mean predictions
            Second column contains uncertainties
    """
    check_is_fitted(self, "has_fitted_")

    X, n_samples = preclean_X(
        instances, self.feature_names_in_, self.feature_types_in_
    )

    # Score the samples once per outer bag, using that bag's slice of
    # the bagged intercepts and per-term score tensors.
    bag_columns = []
    for bag_index in range(len(self.bagged_intercept_)):
        term_scores_for_bag = [
            term_bags[bag_index] for term_bags in self.bagged_scores_
        ]
        bag_columns.append(
            ebm_predict_scores(
                X=X,
                n_samples=n_samples,
                feature_names_in=self.feature_names_in_,
                feature_types_in=self.feature_types_in_,
                bins=self.bins_,
                intercept=self.bagged_intercept_[bag_index],
                term_scores=term_scores_for_bag,
                term_features=self.term_features_,
            )
        )

    # (n_samples, n_bags): one column of raw scores per bagged model
    preds_per_bag = np.column_stack(bag_columns)

    # Column 0: mean prediction across bags. Column 1: std across bags.
    return np.c_[preds_per_bag.mean(axis=1), preds_per_bag.std(axis=1)]

def _multinomialize(self, passthrough=0.0):
check_is_fitted(self, "has_fitted_")

Expand Down
3 changes: 2 additions & 1 deletion python/interpret-core/interpret/utils/_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def construct_bins(
privacy_bounds=None,
):
is_mains = True

for max_bins in max_bins_leveled:
preprocessor = EBMPreprocessor(
feature_names_given,
Expand All @@ -566,7 +567,6 @@ def construct_bins(
)

seed = increment_seed(seed)

preprocessor.fit(X, y, sample_weight)
if is_mains:
is_mains = False
Expand All @@ -582,6 +582,7 @@ def construct_bins(
missing_val_counts = preprocessor.missing_val_counts_
unique_val_counts = preprocessor.unique_val_counts_
noise_scale = preprocessor.noise_scale_

else:
if feature_names_in != preprocessor.feature_names_in_:
msg = "Mismatched feature_names"
Expand Down
32 changes: 32 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,35 @@ def test_ebm_scale():
assert len(clf.bagged_intercept_) == len(clf.bagged_scores_[0])
assert len(clf.standard_deviations_) == 2
assert len(clf.bin_weights_) == 2


def test_ebm_uncertainty():
    """Exercise pred_from_base_models_with_uncertainty: output shape,
    determinism under a fixed seed, and basic sanity of the values."""
    data = synthetic_classification()
    X = data["full"]["X"]
    y = data["full"]["y"]

    model = ExplainableBoostingClassifier(outer_bags=5, random_state=42)
    model.fit(X, y)
    preds = model.pred_from_base_models_with_uncertainty(X)

    # One (mean, uncertainty) pair per sample
    assert preds.shape == (len(X), 2), "Should return (n_samples, 2) shape"

    # Refitting with an identical seed must reproduce the output exactly
    refit = ExplainableBoostingClassifier(outer_bags=5, random_state=42)
    refit.fit(X, y)
    refit_preds = refit.pred_from_base_models_with_uncertainty(X)
    assert np.array_equal(
        preds,
        refit_preds,
    ), "Results should be deterministic with same random seed"

    means = preds[:, 0]
    stds = preds[:, 1]
    assert np.all(np.isfinite(means)), "All predictions should be finite"
    assert np.all(stds >= 0), "Uncertainties should be non-negative"
    assert not np.all(
        stds == stds[0]
    ), "Different samples should have different uncertainties"

0 comments on commit 3fdcab5

Please sign in to comment.