Skip to content

Commit

Permalink
add test for uncertainty prediction function
Browse files Browse the repository at this point in the history
Signed-off-by: Fabian Degen <fabian.degen@mytum.de>
  • Loading branch information
Fabian Degen authored and degenfabian committed Nov 3, 2024
1 parent d666a61 commit 8fcf7c4
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,40 @@ def test_ebm_scale():
assert len(clf.bagged_intercept_) == len(clf.bagged_scores_[0])
assert len(clf.standard_deviations_) == 2
assert len(clf.bin_weights_) == 2


def test_ebm_uncertainty():
data = synthetic_classification()
X = data["full"]["X"]
y = data["full"]["y"]

clf = ExplainableBoostingClassifier(
outer_bags=5,
random_state=42,
)
clf.fit(X, y)

result = clf.pred_from_base_models_with_uncertainty(X)
assert result.shape == (len(X), 2), "Should return (n_samples, 2) shape"

clf2 = ExplainableBoostingClassifier(outer_bags=5, random_state=42)
clf2.fit(X, y)
result_same_seed = clf2.pred_from_base_models_with_uncertainty(X)
assert np.array_equal(
result,
result_same_seed,
), "Results should be deterministic with same random seed"

X_far = X.copy()
X_far["A"] = X_far["A"] + 10.0 # Shift first feature far from training data
preds_far = clf.pred_from_base_models_with_uncertainty(X_far)

assert np.mean(preds_far[:, 1]) > np.mean(
result[:, 1]
), "Uncertainty should be higher for points far from training data"

uncertainties = result[:, 1]
assert np.all(uncertainties >= 0), "Uncertainties should be non-negative"
assert not np.all(
uncertainties == uncertainties[0]
), "Different samples should have different uncertainties"

0 comments on commit 8fcf7c4

Please sign in to comment.