Make causal analysis more robust to high variance
kbattocchi committed Aug 6, 2021
1 parent 7462b66 commit f0aa066
Showing 1 changed file with 55 additions and 22 deletions.
77 changes: 55 additions & 22 deletions econml/solutions/causal_analysis/_causal_analysis.py
@@ -126,7 +126,7 @@ def _get_data_causal_insights_keys():

def _first_stage_reg(X, y, *, automl=True, random_state=None, verbose=0):
if automl:
-        model = GridSearchCVList([make_pipeline(StandardScaler(), LassoCV(random_state=random_state)),
+        model = GridSearchCVList([LassoCV(random_state=random_state),
RandomForestRegressor(
n_estimators=100, random_state=random_state, min_samples_leaf=10),
lgb.LGBMRegressor(num_leaves=32, random_state=random_state)],
@@ -138,25 +138,30 @@ def _first_stage_reg(X, y, *, automl=True, random_state=None, verbose=0):
scoring='r2',
verbose=verbose)
best_est = model.fit(X, y).best_estimator_
-        if isinstance(best_est, Pipeline):
-            return make_pipeline(StandardScaler(), Lasso(alpha=best_est.steps[1][1].alpha_, random_state=random_state))
+        if isinstance(best_est, LassoCV):
+            return Lasso(alpha=best_est.alpha_, random_state=random_state)
else:
return best_est
else:
-        model = make_pipeline(StandardScaler(), LassoCV(cv=5, random_state=random_state)).fit(X, y)
-        return make_pipeline(StandardScaler(), Lasso(alpha=model.steps[1][1].alpha_, random_state=random_state))
+        model = LassoCV(cv=5, random_state=random_state).fit(X, y)
+        return Lasso(alpha=model.alpha_, random_state=random_state)


def _first_stage_clf(X, y, *, make_regressor=False, automl=True, min_count=None, random_state=None, verbose=0):
+    # use same Cs as would be used by default by LogisticRegressionCV
+    cs = np.logspace(-4, 4, 10)
if min_count is None:
min_count = _CAT_LIMIT # we have at least this many instances
if automl:
-        model = GridSearchCVList([make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000,
-                                                                                     random_state=random_state)),
+        # NOTE: we don't use LogisticRegressionCV inside the grid search because of the nested stratification
+        #       which could affect how many times each distinct Y value needs to be present in the data
+
+        model = GridSearchCVList([LogisticRegression(max_iter=1000,
+                                                     random_state=random_state),
RandomForestClassifier(n_estimators=100, min_samples_leaf=10,
random_state=random_state),
lgb.LGBMClassifier(num_leaves=32, random_state=random_state)],
-                                 param_grid_list=[{'logisticregression__C': [0.01, .1, 1, 10, 100]},
+                                 param_grid_list=[{'C': cs},
{'max_depth': [3, None],
'min_weight_fraction_leaf': [.001, .01, .1]},
{'learning_rate': [0.1, 0.3], 'max_depth': [3, 5]}],
@@ -165,10 +170,9 @@ def _first_stage_clf(X, y, *, make_regressor=False, automl=True, min_count=None,
verbose=verbose)
est = model.fit(X, y).best_estimator_
else:
-        model = make_pipeline(StandardScaler(), LogisticRegressionCV(
-            cv=min(5, min_count), max_iter=1000, random_state=random_state)).fit(X, y)
-        est = make_pipeline(StandardScaler(), LogisticRegression(
-            C=model.steps[1][1].C_[0], random_state=random_state))
+        model = LogisticRegressionCV(
+            cv=min(5, min_count), max_iter=1000, Cs=cs, random_state=random_state).fit(X, y)
+        est = LogisticRegression(C=model.C_[0], random_state=random_state)
if make_regressor:
return _RegressionWrapper(est)
else:
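
Both first-stage helpers now follow the same pattern: tune hyperparameters once with a CV estimator, then return an unfitted plain estimator pinned to the chosen values so that later cross-fitting refits it cheaply on every fold. A minimal sketch of that pattern, assuming only scikit-learn (the variable names and data are illustrative, not part of the commit):

    import numpy as np
    from sklearn.linear_model import Lasso, LassoCV

    rng = np.random.RandomState(0)
    X = rng.normal(size=(200, 5))
    y = X[:, 0] + 0.1 * rng.normal(size=200)

    # Tune the regularization strength once with cross-validation...
    tuned = LassoCV(cv=5, random_state=0).fit(X, y)
    # ...then return an unfitted estimator pinned to the selected alpha, so each
    # cross-fitting fold refits a single Lasso instead of repeating the CV search.
    first_stage = Lasso(alpha=tuned.alpha_, random_state=0)
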
@@ -204,14 +208,21 @@ def fit(self, X):
handle_unknown='ignore').fit(cat_cols)
else:
self.has_cats = False
+        cont_cols = _safe_indexing(X, self.passthrough, axis=1)
+        if cont_cols.shape[1] > 0:
+            self.has_conts = True
+            self.scaler = StandardScaler().fit(cont_cols)
+        else:
+            self.has_conts = False
self.d_x = X.shape[1]
return self

def transform(self, X):
rest = _safe_indexing(X, self.passthrough, axis=1)
+        if self.has_conts:
+            rest = self.scaler.transform(rest)
if self.has_cats:
-            cats = self.one_hot_encoder.transform(
-                _safe_indexing(X, self.categorical, axis=1))
+            cats = self.one_hot_encoder.transform(_safe_indexing(X, self.categorical, axis=1))
return np.hstack((cats, rest))
else:
return rest
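
Since scaling no longer happens inside the per-model pipelines, the custom `_ColumnTransformer` takes it over: it standardizes whatever continuous columns pass through, guarding the zero-column case because StandardScaler rejects input with no features. A condensed sketch of that fit/transform logic, with a hypothetical class name:

    from sklearn.preprocessing import StandardScaler
    from sklearn.utils import _safe_indexing  # same private helper this module uses

    class _ScaledPassthrough:
        """Hypothetical condensation of _ColumnTransformer's continuous-column handling."""

        def __init__(self, passthrough):
            self.passthrough = passthrough  # indices of the non-categorical columns

        def fit(self, X):
            cont_cols = _safe_indexing(X, self.passthrough, axis=1)
            self.has_conts = cont_cols.shape[1] > 0
            if self.has_conts:  # StandardScaler raises on input with zero features
                self.scaler = StandardScaler().fit(cont_cols)
            return self

        def transform(self, X):
            rest = _safe_indexing(X, self.passthrough, axis=1)
            return self.scaler.transform(rest) if self.has_conts else rest
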
@@ -288,7 +299,7 @@ def __init__(self, tree_dictionary, policy_value, always_treat, control_name):


def _process_feature(name, feat_ind, verbose, categorical_inds, categories, heterogeneity_inds, min_counts, y, X,
-                     nuisance_models, h_model, random_state, model_y):
+                     nuisance_models, h_model, random_state, model_y, cv, mc_iters):
try:
if verbose > 0:
print(f"CausalAnalysis: Feature {name}")
@@ -304,13 +315,13 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
[ind for ind in categorical_inds
if ind != feat_ind]),
('drop', 'drop', feat_ind)],
-                                          remainder='passthrough')
+                                          remainder=StandardScaler())
W_transformer = ColumnTransformer([('encode', OneHotEncoder(drop='first', sparse=False),
[ind for ind in categorical_inds
if ind != feat_ind and ind not in hinds]),
('drop', 'drop', hinds),
('drop_feat', 'drop', feat_ind)],
-                                          remainder='passthrough')
+                                          remainder=StandardScaler())
# Use _ColumnTransformer instead of ColumnTransformer so we can get feature names
X_transformer = _ColumnTransformer([ind for ind in categorical_inds
if ind != feat_ind and ind in hinds],
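
Replacing `remainder='passthrough'` with `remainder=StandardScaler()` leans on ColumnTransformer's support for an estimator as the remainder: every column not claimed by a named transformer is fed through it. An illustrative toy use (the data and column indices are made up):

    import numpy as np
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import OneHotEncoder, StandardScaler

    X = np.array([[0, 1.0, 10.0],
                  [1, 2.0, 20.0],
                  [0, 3.0, 30.0]])
    ct = ColumnTransformer([('encode', OneHotEncoder(drop='first', sparse=False), [0])],
                           remainder=StandardScaler())  # columns 1 and 2 get standardized
    Xt = ct.fit_transform(X)  # shape (3, 3): one dummy column plus two scaled columns
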
@@ -359,7 +370,9 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
fit_cate_intercept=True,
linear_first_stages=False,
categories=cats,
-                            random_state=random_state)
+                            random_state=random_state,
+                            cv=cv,
+                            mc_iters=mc_iters)
elif h_model == 'forest':
est = CausalForestDML(model_y=model_y,
model_t=model_t,
@@ -368,7 +381,9 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
min_var_leaf_on_val=True,
categories=cats,
random_state=random_state,
-                                  verbose=verbose)
+                                  verbose=verbose,
+                                  cv=cv,
+                                  mc_iters=mc_iters)

if verbose > 0:
print("CausalAnalysis: tuning forest")
Expand Down Expand Up @@ -494,6 +509,20 @@ class CausalAnalysis:
Degree of parallelism to use when training models via joblib.Parallel
verbose : int, default=0
Controls the verbosity when fitting and predicting.
+    cv : int, cross-validation generator or an iterable, default=5
+        Determines the strategy for cross-fitting used when training causal models for each feature.
+        Possible inputs for cv are:
+
+        - integer, to specify the number of folds.
+        - :term:`CV splitter`
+        - An iterable yielding (train, test) splits as arrays of indices.
+
+        For integer inputs, if the treatment is discrete
+        :class:`~sklearn.model_selection.StratifiedKFold` is used, otherwise
+        :class:`~sklearn.model_selection.KFold` is used
+        (with a random shuffle in either case).
+    mc_iters : int, default=3
+        The number of times to rerun the first stage models to reduce the variance of the causal model nuisances.
skip_cat_limit_checks : bool, default=False
By default, categorical features need to have several instances of each category in order for a model to be
fit robustly. Setting this to True will skip these checks (although at least 2 instances will always be
Expand Down Expand Up @@ -525,7 +554,8 @@ class CausalAnalysis:

def __init__(self, feature_inds, categorical, heterogeneity_inds=None, feature_names=None, classification=False,
upper_bound_on_cat_expansion=5, nuisance_models='linear', heterogeneity_model='linear', *,
-                 categories='auto', n_jobs=-1, verbose=0, skip_cat_limit_checks=False, random_state=None):
+                 categories='auto', n_jobs=-1, verbose=0, cv=5, mc_iters=3, skip_cat_limit_checks=False,
+                 random_state=None):
self.feature_inds = feature_inds
self.categorical = categorical
self.heterogeneity_inds = heterogeneity_inds
@@ -537,6 +567,8 @@ def __init__(self, feature_inds, categorical, heterogeneity_inds=None, feature_n
self.categories = categories
self.n_jobs = n_jobs
self.verbose = verbose
+        self.cv = cv
+        self.mc_iters = mc_iters
self.skip_cat_limit_checks = skip_cat_limit_checks
self.random_state = random_state

@@ -650,7 +682,7 @@ def fit(self, X, y, warm_start=False):
OneHotEncoder(
drop='first', sparse=False),
self.categorical)],
-                                       remainder='passthrough').fit_transform(X)
+                                       remainder=StandardScaler()).fit_transform(X)

if self.verbose > 0:
print("CausalAnalysis: performing model selection on overall Y model")
@@ -787,7 +819,8 @@ def fit(self, X, y, warm_start=False):
)(joblib.delayed(_process_feature)(
feat_name, feat_ind,
self.verbose, categorical_inds, categories, heterogeneity_inds, min_counts, y, X,
-            self.nuisance_models, self.heterogeneity_model, self.random_state, self._model_y)
+            self.nuisance_models, self.heterogeneity_model, self.random_state, self._model_y,
+            self.cv, self.mc_iters)
for feat_name, feat_ind in zip(new_feat_names, new_inds))))

# track indices where an exception was thrown, since we can't remove from dictionary while iterating
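
End to end, the same two knobs surface on the CausalAnalysis constructor. A usage sketch on synthetic data (the data, column choices, and effect-summary call here are just for illustration):

    import numpy as np
    from econml.solutions.causal_analysis import CausalAnalysis

    rng = np.random.RandomState(0)
    X = rng.normal(size=(500, 4))
    y = X[:, 0] + 0.1 * rng.normal(size=500)

    ca = CausalAnalysis(feature_inds=[0, 1], categorical=[],
                        nuisance_models='linear', heterogeneity_model='linear',
                        cv=5,        # folds for cross-fitting each per-feature causal model
                        mc_iters=3,  # first-stage reruns, averaged to cut nuisance variance
                        random_state=0)
    ca.fit(X, y)
    effects = ca.global_causal_effect(alpha=0.05)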
