Commit

minimal tree gam implementation

csinva committed Apr 22, 2024
1 parent 0be79e0 commit 82db112
Showing 1 changed file with 226 additions and 0 deletions.

imodels/algebraic/tree_gam_minimal.py
@@ -0,0 +1,226 @@
from copy import deepcopy

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_X_y

import imodels


class TreeGAMMinimal(BaseEstimator):
"""Tree-based GAM classifier.
Uses cyclical boosting to fit a GAM with small trees.
Simplified version of the explainable boosting machine described in https://github.com/interpretml/interpret
Only works for binary classification.
Fits a scalar bias to the mean.
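
    Example
    -------
    A minimal usage sketch (``X`` is any 2D float array and ``y`` any binary
    label vector; both are placeholders):

    >>> gam = TreeGAMMinimalClassifier(learning_rate=0.1, n_boosting_rounds=100)
    >>> gam = gam.fit(X, y)
    >>> probs = gam.predict_proba(X)[:, 1]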
"""

def __init__(
self,
n_boosting_rounds=100,
max_leaf_nodes=3,
reg_param=0.0,
learning_rate: float = 0.01,
boosting_strategy="cyclic",
validation_frac=0.15,
random_state=None,
):
"""
Params
------
n_boosting_rounds : int
Number of boosting rounds for the cyclic boosting.
max_leaf_nodes : int
Maximum number of leaf nodes for the trees in the cyclic boosting.
reg_param : float
Regularization parameter for the cyclic boosting.
        learning_rate : float
            Learning rate for the cyclic boosting.
        boosting_strategy : str ["cyclic", "greedy"]
            Whether to use cyclic boosting (cycle over features) or greedy
            boosting (select the best feature at each step).
        validation_frac : float
            Fraction of the data to hold out for early stopping.
random_state : int
Random seed.
"""
self.n_boosting_rounds = n_boosting_rounds
self.max_leaf_nodes = max_leaf_nodes
self.reg_param = reg_param
self.learning_rate = learning_rate
self.boosting_strategy = boosting_strategy
self.validation_frac = validation_frac
self.random_state = random_state

def fit(self, X, y, sample_weight=None):
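        """Fit by boosting shallow trees on residuals, holding out a
        validation split for early stopping."""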
X, y = check_X_y(X, y, accept_sparse=False, multi_output=False)
if isinstance(self, ClassifierMixin):
check_classification_targets(y)
self.classes_, y = np.unique(y, return_inverse=True)

sample_weight = _check_sample_weight(sample_weight, X, dtype=None)

# split into train and validation for early stopping
(
X_train,
X_val,
y_train,
y_val,
sample_weight_train,
sample_weight_val,
) = train_test_split(
X,
y,
sample_weight,
test_size=self.validation_frac,
random_state=self.random_state,
stratify=y if isinstance(self, ClassifierMixin) else None,
)

self.estimators_ = []
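        # scalar bias: start predictions at the mean of y
        # (note: computed on all of y, before the train/validation split)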
self.bias_ = np.mean(y)

self._cyclic_boost(
X_train,
y_train,
sample_weight_train,
X_val,
y_val,
sample_weight_val,
)

self.mse_val_ = self._calc_mse(X_val, y_val, sample_weight_val)

return self

def _cyclic_boost(
self, X_train, y_train, sample_weight_train, X_val, y_val, sample_weight_val
):
"""Apply cyclic boosting, storing trees in self.estimators_"""

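        # initial residuals are relative to the bias-only model
        # (self.estimators_ is empty here, so predict_proba returns the bias)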
residuals_train = y_train - self.predict_proba(X_train)[:, 1]
mse_val = self._calc_mse(X_val, y_val, sample_weight_val)
for _ in range(self.n_boosting_rounds):
boosting_round_ests = []
boosting_round_mses = []
feature_nums = np.arange(X_train.shape[1])
for feature_num in feature_nums:
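                # zero out every other feature so this tree can only split on
                # feature_num, keeping each tree's contribution univariate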
X_ = np.zeros_like(X_train)
X_[:, feature_num] = X_train[:, feature_num]
est = DecisionTreeRegressor(
max_leaf_nodes=self.max_leaf_nodes,
random_state=self.random_state,
)
est.fit(X_, residuals_train, sample_weight=sample_weight_train)
                # keep the tree only if its root either splits on feature_num
                # or is a leaf (-2); a split on any other feature means the
                # masking failed, so the tree is skipped
                successfully_split_on_feature = np.all(
                    (est.tree_.feature[0] == feature_num)
                    | (est.tree_.feature[0] == -2)
                )
                if not successfully_split_on_feature:
                    continue
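                # optionally shrink the tree's leaf values via hierarchical
                # shrinkage (see imodels.HSTreeRegressor)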
if self.reg_param > 0:
est = imodels.HSTreeRegressor(
est, reg_param=self.reg_param)
self.estimators_.append(est)
                if self.boosting_strategy == "cyclic":
                    # update residuals right away so the next feature's tree
                    # fits what remains unexplained
                    residuals_train = (
                        residuals_train - self.learning_rate * est.predict(X_train)
                    )
                elif self.boosting_strategy == "greedy":
                    # est was just appended above, so this MSE measures the
                    # effect of adding it on top of previous rounds
                    mse_train_new = self._calc_mse(
                        X_train, y_train, sample_weight_train
                    )
                    # pop the candidate back off: in greedy mode only the
                    # round's best estimator is kept (selected below)
                    boosting_round_ests.append(deepcopy(self.estimators_.pop()))
                    boosting_round_mses.append(mse_train_new)

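            # in greedy mode, keep only the best tree from this round and only
            # then update the residuals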
if self.boosting_strategy == "greedy":
best_est = boosting_round_ests[np.argmin(boosting_round_mses)]
self.estimators_.append(best_est)
residuals_train = (
residuals_train - self.learning_rate *
best_est.predict(X_train)
)

# early stopping if validation error does not decrease
mse_val_new = self._calc_mse(X_val, y_val, sample_weight_val)
if mse_val_new >= mse_val:
# print("early stop!")
return
else:
mse_val = mse_val_new

def predict_proba(self, X):
X = check_array(X, accept_sparse=False, dtype=None)
check_is_fitted(self)
        # sum the bias and each tree's (scaled) contribution
        probs1 = np.ones(X.shape[0]) * self.bias_
        for est in self.estimators_:
            probs1 += self.learning_rate * est.predict(X)
        # clip to valid probabilities for classification only; regression
        # outputs should not be truncated to [0, 1]
        if isinstance(self, ClassifierMixin):
            probs1 = np.clip(probs1, a_min=0, a_max=1)
        return np.array([1 - probs1, probs1]).T

    def predict(self, X):
        if isinstance(self, RegressorMixin):
            return self.predict_proba(X)[:, 1]
        elif isinstance(self, ClassifierMixin):
            return np.argmax(self.predict_proba(X), axis=1)
        raise AttributeError(
            "use TreeGAMMinimalRegressor or TreeGAMMinimalClassifier"
        )

def _calc_mse(self, X, y, sample_weight=None):
return np.average(
np.square(y - self.predict_proba(X)[:, 1]),
weights=sample_weight,
)


class TreeGAMMinimalRegressor(TreeGAMMinimal, RegressorMixin):
...


class TreeGAMMinimalClassifier(TreeGAMMinimal, ClassifierMixin):
...


if __name__ == "__main__":
X, y, feature_names = imodels.get_clean_dataset("heart")
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
gam = TreeGAMMinimalClassifier(
boosting_strategy="cyclic",
random_state=42,
learning_rate=0.1,
max_leaf_nodes=3,
n_boosting_rounds=100,
)
    gam.fit(X_train, y_train)

# check roc auc score
y_pred = gam.predict_proba(X_test)[:, 1]
print(f"test roc: {roc_auc_score(y_test, y_pred):.3f}")
print(f"test acc {accuracy_score(y_test, gam.predict(X_test)):.3f}")
print('\t(imb:', np.mean(y_test).round(3), ')')
