Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mehei/otherinferences #203

Merged
merged 28 commits into from
Feb 15, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
613947b
add other analytical inferences for const_marignal_effect
heimengqi Dec 3, 2019
1811f3b
add effect inference and population summary of inference
heimengqi Dec 17, 2019
20cf21b
add docstring
heimengqi Dec 17, 2019
4dfc5d2
Update setup.cfg to avoid Linux segfault issue
kbattocchi Dec 18, 2019
dfa0b3d
linting error
heimengqi Dec 18, 2019
4a6db8f
Merge branch 'mehei/otherinferences' of https://github.com/microsoft/…
heimengqi Dec 18, 2019
9be90a1
Panel() function has been deprecated, change to alternatives
heimengqi Dec 19, 2019
506aa91
fix debiased lasso prediction shape when y is a vector
heimengqi Dec 19, 2019
dec5424
improve population summary, update notebook and test
heimengqi Dec 26, 2019
3075c57
Merge branch 'master' into mehei/otherinferences
heimengqi Dec 26, 2019
01f5be5
linting error
heimengqi Dec 27, 2019
eee32da
Merge branch 'mehei/otherinferences' of https://github.com/microsoft/…
heimengqi Dec 27, 2019
e8da90d
syntax error
heimengqi Dec 27, 2019
7bd9bbb
update docstring
heimengqi Dec 27, 2019
6040702
support inferences for drlearner, update notebook and add test, chang…
heimengqi Jan 2, 2020
359eeba
Add coef__inference and intercept__inference, improvement on output, …
heimengqi Jan 7, 2020
6baa4b2
solve review comment
heimengqi Jan 9, 2020
ed81d96
Merge branch 'master' into mehei/otherinferences
heimengqi Jan 9, 2020
b7bb95a
add test to check whether the CI from inference class equals to CI fr…
heimengqi Jan 9, 2020
253c255
delete the test notebook
heimengqi Feb 5, 2020
40eba63
check fit_cate_intercept for drlearner and throw an error for effect_…
heimengqi Feb 5, 2020
3b22021
Merge branch 'mehei/otherinferences' of https://github.com/microsoft/…
heimengqi Feb 5, 2020
a794c12
Merge branch 'master' into mehei/otherinferences
heimengqi Feb 6, 2020
83552c9
fix linting error and drlearner inference support 2d response array
heimengqi Feb 6, 2020
c034308
add summary function for coef and intercept, fix shape inconsistence …
heimengqi Feb 11, 2020
dde15ce
solve conflict
heimengqi Feb 11, 2020
f3eb706
solve review comment
heimengqi Feb 14, 2020
101140c
Merge branch 'master' into mehei/otherinferences
heimengqi Feb 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9,395 changes: 0 additions & 9,395 deletions Other inferences test.ipynb

This file was deleted.

28 changes: 27 additions & 1 deletion econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def marginal_effect_inference(self, T, X=None):
pred = np.repeat(pred, shape(T)[0], axis=0)
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
return InferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', pred_dist=None, fn_transformer=None)
pred_stderr=pred_stderr, inf_type='effect', pred_dist=None, fname_transformer=None)
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__

@BaseCateEstimator._defer_to_inference
Expand Down Expand Up @@ -574,6 +574,19 @@ def intercept__inference(self):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect.

Returns
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
-------
smry : Summary instance
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
"""
Expand Down Expand Up @@ -725,6 +738,19 @@ def intercept__inference(self, T):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect associated with treatment T.

Returns
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
-------
smry : Summary instance
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass


class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscreteMixin):
"""
Expand Down
2 changes: 2 additions & 0 deletions econml/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ def __init__(self,
fit_cate_intercept=True,
min_propensity=1e-6,
n_splits=2, random_state=None):
self.fit_cate_intercept = fit_cate_intercept
super().__init__(model_propensity=model_propensity,
model_regression=model_regression,
model_final=StatsModelsLinearRegression(fit_intercept=fit_cate_intercept),
Expand Down Expand Up @@ -837,6 +838,7 @@ def __init__(self,
tol=1e-4,
min_propensity=1e-6,
n_splits=2, random_state=None):
self.fit_cate_intercept = fit_cate_intercept
model_final = DebiasedLasso(
alpha=alpha,
fit_intercept=fit_cate_intercept,
Expand Down
175 changes: 128 additions & 47 deletions econml/inference.py

Large diffs are not rendered by default.

51 changes: 38 additions & 13 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def make_random(is_discrete, d):
d_t_final = 2 if is_discrete else d_t

effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1),) + (6,)
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1), 6)
marginal_effect_shape = ((n,) +
((d_y,) if d_y > 0 else ()) +
((d_t_final,) if d_t_final > 0 else ()))
Expand Down Expand Up @@ -202,6 +202,9 @@ def make_random(is_discrete, d):
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.conf_int()),
(2,) + const_marginal_effect_shape)
np.testing.assert_array_almost_equal(const_marg_effect_inf.conf_int()
[0], const_marg_eff_int[0],
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
decimal=5)
const_marg_effect_inf.population_summary()._repr_html_()

# test effect inference
Expand All @@ -219,6 +222,9 @@ def make_random(is_discrete, d):
effect_shape)
self.assertEqual(shape(effect_inf.conf_int()),
(2,) + effect_shape)
np.testing.assert_array_almost_equal(effect_inf.conf_int()
[0], est.effect_interval
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
(X, T0=T0, T1=T1)[0], decimal=5)
effect_inf.population_summary()._repr_html_()

# test marginal effect inference
Expand All @@ -236,31 +242,42 @@ def make_random(is_discrete, d):
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.conf_int()),
(2,) + marginal_effect_shape)
np.testing.assert_array_almost_equal(marg_effect_inf.conf_int()
[0], marg_eff_int[0], decimal=5)
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
marg_effect_inf.population_summary()._repr_html_()

# test coef__inference and intercept__inference
if (isinstance(est,
LinearDMLCateEstimator) or
isinstance(est,
SparseLinearDMLCateEstimator)):
if X is not None:
if X is None:
cm = pytest.raises(AttributeError)
else:
cm = ExitStack()
# ExitStack can be used as a "do nothing" ContextManager
with cm:
self.assertEqual(
shape(est.coef__inference().summary_frame()),
coef_summaryframe_shape)
else:
with pytest.raises(AttributeError):
self.assertEqual(
shape(est.coef__inference().summary_frame()),
coef_summaryframe_shape)
np.testing.assert_array_almost_equal(
est.coef__inference().conf_int()
[0], est.coef__interval()[0], decimal=5)

if fit_cate_intercept:
cm = ExitStack()
# ExitStack can be used as a "do nothing" ContextManager
else:
cm = pytest.raises(AttributeError)
with cm:
self.assertEqual(shape(est.intercept__inference().
summary_frame()),
intercept_summaryframe_shape)
else:
with pytest.raises(AttributeError):
self.assertEqual(shape(est.intercept__inference().
summary_frame()),
intercept_summaryframe_shape)
np.testing.assert_array_almost_equal(
est.intercept__inference().conf_int()
[0], est.intercept__interval()[0], decimal=5)

est.summary()

est.score(Y, T, X, W)

Expand Down Expand Up @@ -310,7 +327,7 @@ def make_random(is_discrete, d):
d_t_final = 1 if is_discrete else d_t

effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1),) + (6,)
effect_summaryframe_shape = (n * (d_y if d_y > 0 else 1), 6)
marginal_effect_shape = ((n,) +
((d_y,) if d_y > 0 else ()) +
((d_t_final,) if d_t_final > 0 else ()))
Expand Down Expand Up @@ -404,6 +421,9 @@ def make_random(is_discrete, d):
const_marginal_effect_shape)
self.assertEqual(shape(const_marg_effect_inf.conf_int()),
(2,) + const_marginal_effect_shape)
np.testing.assert_array_almost_equal(const_marg_effect_inf.conf_int()
[0], const_marg_eff_int[0],
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
decimal=5)
const_marg_effect_inf.population_summary()._repr_html_()

# test effect inference
Expand All @@ -421,6 +441,9 @@ def make_random(is_discrete, d):
effect_shape)
self.assertEqual(shape(effect_inf.conf_int()),
(2,) + effect_shape)
np.testing.assert_array_almost_equal(effect_inf.conf_int()
[0], est.effect_interval
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
(X, T0=T0, T1=T1)[0], decimal=5)
effect_inf.population_summary()._repr_html_()

# test marginal effect inference
Expand All @@ -438,6 +461,8 @@ def make_random(is_discrete, d):
marginal_effect_shape)
self.assertEqual(shape(marg_effect_inf.conf_int()),
(2,) + marginal_effect_shape)
np.testing.assert_array_almost_equal(marg_effect_inf.conf_int()
[0], marg_eff_int[0], decimal=5)
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
marg_effect_inf.population_summary()._repr_html_()

est.score(Y, T, X, W)
Expand Down
Loading