Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hyperparameter tuning for CausalForestDML #390

Merged
merged 18 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class in this module implements the general logic in a very versatile way
from collections import namedtuple
from warnings import warn
from abc import abstractmethod
import inspect
from collections import defaultdict
import re
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np
from sklearn.base import clone
Expand Down Expand Up @@ -436,6 +439,33 @@ def __init__(self, *,
self.mc_agg = mc_agg
super().__init__()

@classmethod
def _get_param_names(cls):
"""Get parameter names for the estimator"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved
init = cls.__init__
if init is object.__init__:
# No explicit constructor to introspect
return []

# introspect the constructor arguments to find the model parameters
# to represent
init_signature = inspect.signature(init)
# Consider the constructor parameters excluding 'self'
parameters = [p for p in init_signature.parameters.values()
if p.name != 'self' and p.kind != p.VAR_KEYWORD]
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError("ortho learner cate estimators should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
" %s with constructor %s doesn't "
" follow this convention."
% (cls, init_signature))
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved
# Extract and sort argument names excluding 'self'
return sorted([p.name for p in parameters])

@abstractmethod
def _gen_ortho_learner_model_nuisance(self):
""" Must return a fresh instance of a nuisance model
Expand Down
111 changes: 105 additions & 6 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@
from warnings import warn

import numpy as np
from sklearn.linear_model import LogisticRegressionCV
from sklearn.base import clone, BaseEstimator
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from itertools import product
from .dml import _BaseDML
from .dml import _FirstStageWrapper, _FinalWrapper
from ..sklearn_extensions.linear_model import WeightedLassoCVWrapper
from ..sklearn_extensions.model_selection import WeightedStratifiedKFold
from ..inference import NormalInferenceResults
from ..inference._inference import Inference
from sklearn.linear_model import LogisticRegressionCV
from sklearn.base import clone, BaseEstimator
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from ..utilities import add_intercept, shape, check_inputs, _deprecate_positional
from ..grf import CausalForest, MultiOutputGRF
from .._cate_estimator import LinearCateEstimator
from .._shap import _shap_explain_multitask_model_cate
from .._ortho_learner import _OrthoLearner
from ..score import RScorer


class _CausalForestFinalWrapper(_FinalWrapper):
Expand Down Expand Up @@ -173,7 +176,7 @@ class CausalForestDML(_BaseDML):

sum_{child} E[(Y - <theta(child), T> - beta(child))^2 | X=child] weight(child)

Internally, for the case of more than two treatments or for the case of one treatment with
Internally, for the case of more than two treatments or for the case of two treatments with
``fit_intercept=True`` then this criterion is approximated by computationally simpler variants for
computational purposes. In particular, it is replaced by:
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -394,7 +397,7 @@ def __init__(self, *,
min_samples_leaf=5,
min_weight_fraction_leaf=0.,
min_var_fraction_leaf=None,
min_var_leaf_on_val=True,
min_var_leaf_on_val=False,
max_features="auto",
min_impurity_decrease=0.,
max_samples=.45,
Expand Down Expand Up @@ -498,7 +501,103 @@ def _gen_model_final(self):
def _gen_rlearner_model_final(self):
    """Build the final-stage model object handed to the R-learner.

    Wraps the model produced by ``_gen_model_final`` together with the
    featurizer produced by ``_gen_featurizer`` in a
    ``_CausalForestFinalWrapper``.
    """
    final_model = self._gen_model_final()
    featurizer = self._gen_featurizer()
    # NOTE(review): the two ``False`` positional flags are passed through
    # unchanged; their meaning is defined by the wrapper's constructor,
    # which is not visible here -- confirm against ``_FinalWrapper``.
    return _CausalForestFinalWrapper(final_model, False, featurizer, False)

def tune(self, Y, T, *, X=None, W=None,
         sample_weight=None, sample_var=None, groups=None,
         params='auto'):
    """
    Tune the major hyperparameters of the final stage causal forest based on out-of-sample R-score
    performance.

    The rows are split 70/30 into a training and a validation part. A clone of this estimator with
    100 trees is fit on the training part for the first combination of the grid (with
    ``cache_values=True``) and ``refit_final`` is called on it for every subsequent combination. Each
    candidate is scored with an `RScorer` that was fit on the validation part. After the function is
    called, the parameters of `self` that appear in the grid have been set to the best performing
    combination.

    Note that `self` is NOT fitted by this method: all training happens on an internal clone, so
    `fit` must still be called afterwards, e.g.::

        est.tune(Y, T, X=X, W=W)
        est.fit(Y, T, X=X, W=W)

    Tuning is kept separate from fitting so that it can be run on a subset of the data.

    Parameters
    ----------
    Y: (n × d_y) matrix or vector of length n
        Outcomes for each sample
    T: (n × dₜ) matrix or vector of length n
        Treatments for each sample
    X: (n × dₓ) matrix
        Features for each sample
    W: optional (n × d_w) matrix
        Controls for each sample
    sample_weight: optional (n,) vector
        Weights for each row
    sample_var: optional (n, n_y) vector
        Variance of sample, in case it corresponds to summary of many samples. Currently
        not in use by this method (as inference method does not require sample variance info).
    groups: (n,) vector, optional
        All rows corresponding to the same group will be kept together during splitting.
        If groups is not None, the `cv` argument passed to this class's initializer
        must support a 'groups' argument to its split method.
    params: dict or 'auto', optional (default='auto')
        A dictionary that contains the grid of hyperparameters to try, i.e.
        {'param1': [value1, value2, ...], 'param2': [value1, value2, ...], ...}
        If `params='auto'`, then a default grid is used.

    Returns
    -------
    self : CausalForestDML object
        The tuned causal forest object. This is the same object (not a copy) as the original one, but where
        the tuned parameters of the object have been set to the best performing parameters from the tuning
        grid. The returned object has not been fitted.

    Raises
    ------
    ValueError
        If a key of a custom `params` grid is not a constructor parameter of the estimator, or if
        the grid contains no combination to try (some parameter has an empty list of values).
    """
    if params == 'auto':
        params = {'max_samples': [.3, .5],
                  'min_balancedness_tol': [.3, .5],
                  'min_samples_leaf': [5, 50],
                  'max_depth': [3, None],
                  'min_var_fraction_leaf': [None, .01]}
    else:
        # If custom param grid, check that only estimator parameters are being altered
        estimator_param_names = set(self._get_param_names())
        for key in params.keys():
            if key not in estimator_param_names:
                raise ValueError("Parameter `{}` is not an estimator parameter.".format(key))

    names = list(params.keys())
    grid = list(product(*params.values()))
    if not grid:
        # Fail before any model is trained rather than inside np.argmax on an empty score list.
        raise ValueError("The hyperparameter grid is empty: every parameter in `params` "
                         "must list at least one value.")

    train, test = train_test_split(np.arange(Y.shape[0]), train_size=.7, random_state=self.random_state)
    ytrain, yval, Ttrain, Tval = Y[train], Y[test], T[train], T[test]
    Xtrain, Xval = (X[train], X[test]) if X is not None else (None, None)
    Wtrain, Wval = (W[train], W[test]) if W is not None else (None, None)
    groups_train, groups_val = (groups[train], groups[test]) if groups is not None else (None, None)
    if sample_weight is not None:
        sample_weight_train, sample_weight_val = sample_weight[train], sample_weight[test]
    else:
        sample_weight_train, sample_weight_val = None, None
    # Only the training slice of sample_var is ever consumed (the scorer does not take it).
    sample_var_train = sample_var[train] if sample_var is not None else None

    # Work on a clone so that `self` is only touched once the best combination is known.
    est = clone(self, safe=False)
    est.n_estimators = 100
    est.inference = False

    scorer = RScorer(model_y=est.model_y, model_t=est.model_t,
                     discrete_treatment=est.discrete_treatment, categories=est.categories,
                     cv=est.cv, mc_iters=est.mc_iters, mc_agg=est.mc_agg,
                     random_state=est.random_state)
    scorer.fit(yval, Tval, X=Xval, W=Wval, sample_weight=sample_weight_val, groups=groups_val)

    scores = []
    for it, values in enumerate(grid):
        for key, value in zip(names, values):
            setattr(est, key, value)
        if it == 0:
            # Full fit once, caching values so that later candidates can call refit_final.
            est.fit(ytrain, Ttrain, X=Xtrain, W=Wtrain, sample_weight=sample_weight_train,
                    sample_var=sample_var_train, groups=groups_train, cache_values=True)
        else:
            est.refit_final()
        scores.append((scorer.score(est), tuple(zip(names, values))))

    bestind = np.argmax([s[0] for s in scores])
    _, best_params = scores[bestind]
    for key, value in best_params:
        setattr(self, key, value)

    return self

# override only so that we can update the docstring to indicate support for `blb`

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
Expand Down
2 changes: 1 addition & 1 deletion econml/grf/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class CausalForest(BaseGRF):

sum_{child} E[(Y - <theta(child), T> - beta(child))^2 | X=child] weight(child)

Internally, for the case of more than two treatments or for the case of one treatment with
Internally, for the case of more than two treatments or for the case of two treatments with
``fit_intercept=True`` then this criterion is approximated by computationally simpler variants for
computational purposes. In particular, it is replaced by::
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
6 changes: 5 additions & 1 deletion econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def true_fn(x):
y_sum = np.concatenate((y1_sum, y2_sum)) # outcome
n_sum = np.concatenate((n1_sum, n2_sum)) # number of summarized points
var_sum = np.concatenate((var1_sum, var2_sum)) # variance of the summarized points
for summarized, min_samples_leaf in [(False, 20), (True, 1)]:
for summarized, min_samples_leaf, tune in [(False, 20, False), (True, 1, False), (False, 20, True)]:
est = CausalForestDML(model_y=GradientBoostingRegressor(n_estimators=30, min_samples_leaf=30),
model_t=GradientBoostingClassifier(n_estimators=30, min_samples_leaf=30),
discrete_treatment=True,
Expand All @@ -626,6 +626,8 @@ def true_fn(x):
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
else:
if tune:
est.tune(y, T, X=X[:, :4], W=X[:, 4:])
est.fit(y, T, X=X[:, :4], W=X[:, 4:])
X_test = np.array(list(itertools.product([0, 1], repeat=4)))
point = est.effect(X_test)
Expand All @@ -650,6 +652,8 @@ def true_fn(x):
est.fit(y_sum, T_sum, X=X_sum[:, :4], W=X_sum[:, 4:],
sample_weight=n_sum)
else:
if tune:
est.tune(y, T, X=X[:, :4], W=X[:, 4:], params={'max_samples': [.1, .3]})
est.fit(y, T, X=X[:, :4], W=X[:, 4:])
X_test = np.array(list(itertools.product([0, 1], repeat=4)))
point = est.effect(X_test)
Expand Down
265 changes: 137 additions & 128 deletions notebooks/Causal Forest and Orthogonal Random Forest Examples.ipynb

Large diffs are not rendered by default.

533 changes: 272 additions & 261 deletions notebooks/Double Machine Learning Examples.ipynb

Large diffs are not rendered by default.

177 changes: 116 additions & 61 deletions notebooks/ForestLearners Basic Example.ipynb

Large diffs are not rendered by default.