ENHANCEMENT: Add a prediction function that also calculates uncertainty in the prediction #584

Merged
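The approach, as a minimal NumPy sketch with synthetic per-bag scores (not the EBM API itself): each outer bag scores every sample, the mean across bags is the point prediction, and the standard deviation across bags is reported as the uncertainty.

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-bag raw scores: 4 samples scored by 5 outer bags.
preds_per_bag = rng.normal(size=(4, 5))

mean_pred = preds_per_bag.mean(axis=1)   # point prediction per sample
uncertainty = preds_per_bag.std(axis=1)  # disagreement across bags per sample

print(np.c_[mean_pred, uncertainty])     # shape (4, 2), like the new method's output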
38 changes: 38 additions & 0 deletions python/interpret-core/interpret/glassbox/_ebm/_ebm.py
@@ -2255,6 +2255,44 @@ def scale(self, term, factor):

        return self

    def pred_from_base_models_with_uncertainty(self, instances):
        """Get raw scores and uncertainties from the bagged base models.

        Generates predictions by averaging outputs across all bagged models
        and estimates uncertainty using the standard deviation of predictions
        across bags.

        Args:
            instances: ndarray of shape (n_samples, n_features)
                The input samples to predict on.

        Returns:
            ndarray of shape (n_samples, 2)
                The first column contains the mean predictions.
                The second column contains the uncertainties.
        """
        check_is_fitted(self, "has_fitted_")

        X, n_samples = preclean_X(
            instances, self.feature_names_in_, self.feature_types_in_
        )
        preds_per_bag = np.zeros((n_samples, len(self.bagged_intercept_)))
        # Get predictions from each bagged model
        for bag_index in range(len(self.bagged_intercept_)):
            # Use slices from bagged parameters for this specific model
            scores = ebm_predict_scores(
                X=X,
                n_samples=n_samples,
                feature_names_in=self.feature_names_in_,
                feature_types_in=self.feature_types_in_,
                bins=self.bins_,
                intercept=self.bagged_intercept_[bag_index],
                term_scores=[scores[bag_index] for scores in self.bagged_scores_],
                term_features=self.term_features_,
            )
            preds_per_bag[:, bag_index] = scores

        # Calculate mean predictions and uncertainties
        return np.c_[np.mean(preds_per_bag, axis=1), np.std(preds_per_bag, axis=1)]

    def _multinomialize(self, passthrough=0.0):
        check_is_fitted(self, "has_fitted_")

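A hedged usage sketch of pred_from_base_models_with_uncertainty on a fitted EBM; the synthetic data, the regressor choice, and the hyperparameters below are illustrative assumptions, not part of this diff.

import numpy as np
from interpret.glassbox import ExplainableBoostingRegressor

# Illustrative synthetic regression data (assumed, not from this PR).
rng = np.random.default_rng(0)
X_train = rng.uniform(size=(200, 3))
y_train = X_train @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)

ebm = ExplainableBoostingRegressor(outer_bags=8, random_state=0)
ebm.fit(X_train, y_train)

X_new = rng.uniform(size=(10, 3))
result = ebm.pred_from_base_models_with_uncertainty(X_new)
mean_scores = result[:, 0]    # average raw score across the 8 bags
uncertainties = result[:, 1]  # standard deviation of raw scores across bags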
3 changes: 2 additions & 1 deletion python/interpret-core/interpret/utils/_preprocessor.py
@@ -550,6 +550,7 @@ def construct_bins(
    privacy_bounds=None,
):
    is_mains = True

    for max_bins in max_bins_leveled:
        preprocessor = EBMPreprocessor(
            feature_names_given,
@@ -566,7 +567,6 @@
        )

        seed = increment_seed(seed)

        preprocessor.fit(X, y, sample_weight)
        if is_mains:
            is_mains = False
@@ -582,6 +582,7 @@
            missing_val_counts = preprocessor.missing_val_counts_
            unique_val_counts = preprocessor.unique_val_counts_
            noise_scale = preprocessor.noise_scale_

        else:
            if feature_names_in != preprocessor.feature_names_in_:
                msg = "Mismatched feature_names"
32 changes: 32 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_ebm.py
@@ -1167,3 +1167,35 @@ def test_ebm_scale():
    assert len(clf.bagged_intercept_) == len(clf.bagged_scores_[0])
    assert len(clf.standard_deviations_) == 2
    assert len(clf.bin_weights_) == 2


def test_ebm_uncertainty():
    data = synthetic_classification()
    X = data["full"]["X"]
    y = data["full"]["y"]

    clf = ExplainableBoostingClassifier(
        outer_bags=5,
        random_state=42,
    )
    clf.fit(X, y)

    result = clf.pred_from_base_models_with_uncertainty(X)
    assert result.shape == (len(X), 2), "Should return (n_samples, 2) shape"

    clf2 = ExplainableBoostingClassifier(outer_bags=5, random_state=42)
    clf2.fit(X, y)
    result_same_seed = clf2.pred_from_base_models_with_uncertainty(X)
    assert np.array_equal(
        result,
        result_same_seed,
    ), "Results should be deterministic with same random seed"

    mean_predictions = result[:, 0]
    assert np.all(np.isfinite(mean_predictions)), "All predictions should be finite"

    uncertainties = result[:, 1]
    assert np.all(uncertainties >= 0), "Uncertainties should be non-negative"
    assert not np.all(
        uncertainties == uncertainties[0]
    ), "Different samples should have different uncertainties"
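A hedged follow-up sketch, layered on top of the test rather than asserted by this PR: for a binary classifier the returned values are raw scores on the logit scale, so the per-sample spread can be pushed through the logistic function to get a rough probability band. The data, names, and sizes below are assumptions for illustration.

import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier

def _logistic(z):
    return 1.0 / (1.0 + np.exp(-z))

# Illustrative synthetic binary-classification data (assumed, not from this PR).
rng = np.random.default_rng(1)
X_bin = rng.uniform(size=(300, 4))
y_bin = (X_bin[:, 0] + 0.5 * X_bin[:, 1] > 0.8).astype(int)

clf = ExplainableBoostingClassifier(outer_bags=5, random_state=0)
clf.fit(X_bin, y_bin)

result = clf.pred_from_base_models_with_uncertainty(X_bin[:5])
mean_score, score_std = result[:, 0], result[:, 1]

prob_lower = _logistic(mean_score - score_std)  # rough lower probability band
prob_point = _logistic(mean_score)              # point probability estimate
prob_upper = _logistic(mean_score + score_std)  # rough upper probability band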