Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hyperparameter tuning for CausalForestDML #390

Merged
merged 18 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 119 additions & 7 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@
from warnings import warn

import numpy as np
from sklearn.linear_model import LogisticRegressionCV
from sklearn.base import clone, BaseEstimator
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from itertools import product
from .dml import _BaseDML
from .dml import _FirstStageWrapper, _FinalWrapper
from ..sklearn_extensions.linear_model import WeightedLassoCVWrapper
from ..sklearn_extensions.model_selection import WeightedStratifiedKFold
from ..inference import NormalInferenceResults
from ..inference._inference import Inference
from sklearn.linear_model import LogisticRegressionCV
from sklearn.base import clone, BaseEstimator
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from ..utilities import add_intercept, shape, check_inputs, _deprecate_positional
from ..grf import CausalForest, MultiOutputGRF
from .._cate_estimator import LinearCateEstimator
from .._shap import _shap_explain_multitask_model_cate
from .._ortho_learner import _OrthoLearner
from ..score import RScorer


class _CausalForestFinalWrapper(_FinalWrapper):
Expand Down Expand Up @@ -173,9 +176,9 @@ class CausalForestDML(_BaseDML):

sum_{child} E[(Y - <theta(child), T> - beta(child))^2 | X=child] weight(child)

Internally, for the case of more than two treatments or for the case of one treatment with
Internally, for the case of more than two treatments or for the case of two treatments with
``fit_intercept=True`` then this criterion is approximated by computationally simpler variants for
computationaly purposes. In particular, it is replaced by:
computational purposes. In particular, it is replaced by:

.. code-block::

Expand Down Expand Up @@ -394,7 +397,7 @@ def __init__(self, *,
min_samples_leaf=5,
min_weight_fraction_leaf=0.,
min_var_fraction_leaf=None,
min_var_leaf_on_val=True,
min_var_leaf_on_val=False,
max_features="auto",
min_impurity_decrease=0.,
max_samples=.45,
Expand Down Expand Up @@ -498,7 +501,116 @@ def _gen_model_final(self):
def _gen_rlearner_model_final(self):
return _CausalForestFinalWrapper(self._gen_model_final(), False, self._gen_featurizer(), False)

@property
def tunable_params(self):
return ['n_estimators', 'criterion', 'max_depth', 'min_samples_split', 'min_samples_leaf',
'min_weight_fraction_leaf', 'min_var_fraction_leaf', 'min_var_leaf_on_val',
'max_features', 'min_impurity_decrease', 'max_samples', 'min_balancedness_tol',
'honest', 'inference', 'fit_intercept', 'subforest_size']

def tune(self, Y, T, *, X=None, W=None,
sample_weight=None, sample_var=None, groups=None,
params='auto'):
"""
Tunes the major hyperparameters of the final stage causal forest based on out-of-sample R-score
performance. It trains small forests of size 100 trees on a grid of parameters and tests the
out of sample R-score. After the function is called, then all parameters of `self` have been
set to the optimal hyperparameters found. The estimator however remains un-fitted, so you need to
call fit afterwards to fit the estimator with the chosen hyperparameters. The list of tunable parameters
can be accessed via the property `tunable_params`.

Parameters
----------
Y: (n × d_y) matrix or vector of length n
Outcomes for each sample
T: (n × dₜ) matrix or vector of length n
Treatments for each sample
X: (n × dₓ) matrix
Features for each sample
W: optional (n × d_w) matrix
Controls for each sample
sample_weight: optional (n,) vector
Weights for each row
sample_var: optional (n, n_y) vector
Variance of sample, in case it corresponds to summary of many samples. Currently
not in use by this method (as inference method does not require sample variance info).
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the `cv` argument passed to this class's initializer
must support a 'groups' argument to its split method.
params: dict or 'auto', optional (default='auto')
A dictionary that contains the grid of hyperparameters to try, i.e.
{'param1': [value1, value2, ...], 'param2': [value1, value2, ...], ...}
If `params='auto'`, then a default grid is used.

Returns
-------
self : CausalForestDML object
The tuned causal forest object. This is the same object (not a copy) as the original one, but where
all parameters of the object have been set to the best performing parameters from the tuning grid.
"""
if params == 'auto':
params = {'max_samples': [.3, .5],
'min_balancedness_tol': [.3, .5],
'min_samples_leaf': [5, 50],
'max_depth': [3, None],
'min_var_fraction_leaf': [None, .01]}
else:
# If custom param grid, check that only estimator parameters are being altered
estimator_param_names = self.tunable_params
for key in params.keys():
if key not in estimator_param_names:
raise ValueError(f"Parameter `{key}` is not an tunable causal forest parameter.")

strata = None
if self.discrete_treatment:
strata = self._strata(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups)
train, test = train_test_split(np.arange(Y.shape[0]), train_size=.7,
random_state=self.random_state, stratify=strata)
ytrain, yval, Ttrain, Tval = Y[train], Y[test], T[train], T[test]
Xtrain, Xval = (X[train], X[test]) if X is not None else (None, None)
Wtrain, Wval = (W[train], W[test]) if W is not None else (None, None)
groups_train, groups_val = (groups[train], groups[test]) if groups is not None else (None, None)
if sample_weight is not None:
sample_weight_train, sample_weight_val = sample_weight[train], sample_weight[test]
else:
sample_weight_train, sample_weight_val = None, None
if sample_var is not None:
sample_var_train, _ = sample_var[train], sample_var[test]
else:
sample_var_train, _ = None, None

est = clone(self, safe=False)
est.n_estimators = 100
est.inference = False

scorer = RScorer(model_y=est.model_y, model_t=est.model_t,
discrete_treatment=est.discrete_treatment, categories=est.categories,
cv=est.cv, mc_iters=est.mc_iters, mc_agg=est.mc_agg,
random_state=est.random_state)
scorer.fit(yval, Tval, X=Xval, W=Wval, sample_weight=sample_weight_val, groups=groups_val)

names = params.keys()
scores = []
for it, values in enumerate(product(*params.values())):
for key, value in zip(names, values):
setattr(est, key, value)
if it == 0:
est.fit(ytrain, Ttrain, X=Xtrain, W=Wtrain, sample_weight=sample_weight_train,
sample_var=sample_var_train, groups=groups_train, cache_values=True)
else:
est.refit_final()
scores.append((scorer.score(est), tuple(zip(names, values))))

bestind = np.argmax([s[0] for s in scores])
_, best_params = scores[bestind]
for key, value in best_params:
setattr(self, key, value)

return self

# override only so that we can update the docstring to indicate support for `blb`

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
Expand Down
4 changes: 2 additions & 2 deletions econml/grf/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ class CausalForest(BaseGRF):

sum_{child} E[(Y - <theta(child), T> - beta(child))^2 | X=child] weight(child)

Internally, for the case of more than two treatments or for the case of one treatment with
Internally, for the case of more than two treatments or for the case of two treatments with
``fit_intercept=True`` then this criterion is approximated by computationally simpler variants for
computationaly purposes. In particular, it is replaced by::
computational purposes. In particular, it is replaced by::

sum_{child} weight(child) * rho(child).T @ E[(T;1) @ (T;1).T | X in child] @ rho(child)

Expand Down
6 changes: 5 additions & 1 deletion econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def true_fn(x):
y_sum = np.concatenate((y1_sum, y2_sum)) # outcome
n_sum = np.concatenate((n1_sum, n2_sum)) # number of summarized points
var_sum = np.concatenate((var1_sum, var2_sum)) # variance of the summarized points
for summarized, min_samples_leaf in [(False, 20), (True, 1)]:
for summarized, min_samples_leaf, tune in [(False, 20, False), (True, 1, False), (False, 20, True)]:
est = CausalForestDML(model_y=GradientBoostingRegressor(n_estimators=30, min_samples_leaf=30),
model_t=GradientBoostingClassifier(n_estimators=30, min_samples_leaf=30),
discrete_treatment=True,
Expand All @@ -626,6 +626,8 @@ def true_fn(x):
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
else:
if tune:
est.tune(y, T, X=X[:, :4], W=X[:, 4:])
est.fit(y, T, X=X[:, :4], W=X[:, 4:])
X_test = np.array(list(itertools.product([0, 1], repeat=4)))
point = est.effect(X_test)
Expand All @@ -650,6 +652,8 @@ def true_fn(x):
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
else:
if tune:
est.tune(y, T, X=X[:, :4], W=X[:, 4:], params={'max_samples': [.1, .3]})
est.fit(y, T, X=X[:, :4], W=X[:, 4:])
X_test = np.array(list(itertools.product([0, 1], repeat=4)))
point = est.effect(X_test)
Expand Down
265 changes: 137 additions & 128 deletions notebooks/Causal Forest and Orthogonal Random Forest Examples.ipynb

Large diffs are not rendered by default.

533 changes: 272 additions & 261 deletions notebooks/Double Machine Learning Examples.ipynb

Large diffs are not rendered by default.

175 changes: 114 additions & 61 deletions notebooks/ForestLearners Basic Example.ipynb

Large diffs are not rendered by default.