Skip to content

Commit

Permalink
Mehei/otherinferences (#203)
Browse files Browse the repository at this point in the history
* add analytical effect/marginal effect/constant marginal effect inferences for DML and DRLearner
* add coefficient inference and intercept inference for linear final model
* add population summary inference given dataset X
  • Loading branch information
heimengqi authored Feb 15, 2020
1 parent e72422d commit f42251e
Show file tree
Hide file tree
Showing 11 changed files with 3,612 additions and 825 deletions.
10 changes: 10 additions & 0 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,11 +543,21 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
return super().const_marginal_effect_interval(X, alpha=alpha)
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__

def const_marginal_effect_inference(self, X=None):
self._check_fitted_dims(X)
return super().const_marginal_effect_inference(X)
const_marginal_effect_inference.__doc__ = LinearCateEstimator.const_marginal_effect_inference.__doc__

def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
self._check_fitted_dims(X)
return super().effect_interval(X, T0=T0, T1=T1, alpha=alpha)
effect_interval.__doc__ = LinearCateEstimator.effect_interval.__doc__

def effect_inference(self, X=None, *, T0=0, T1=1):
self._check_fitted_dims(X)
return super().effect_inference(X, T0=T0, T1=T1)
effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__

def score(self, Y, T, X=None, W=None, Z=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
Expand Down
189 changes: 188 additions & 1 deletion econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete
LinearModelFinalInferenceDiscrete, InferenceResults


class BaseCateEstimator(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -197,6 +197,30 @@ def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
"""
pass

@_defer_to_inference
def effect_inference(self, X=None, *, T0=0, T1=1):
""" Inference results for the quantities :math:`\\tau(X, T0, T1)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.
Parameters
----------
X: optional (m, d_x) matrix
Features for each sample
T0: optional (m, d_t) matrix or vector of length m (Default=0)
Base treatments for each sample
T1: optional (m, d_t) matrix or vector of length m (Default=1)
Target treatments for each sample
Returns
-------
InferenceResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass

@_defer_to_inference
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\partial \\tau(T, X)` produced
Expand All @@ -221,6 +245,28 @@ def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
"""
pass

@_defer_to_inference
def marginal_effect_inference(self, T, X=None):
""" Inference results for the quantities :math:`\\partial \\tau(T, X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.
Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
Returns
-------
InferenceResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass


class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""
Expand Down Expand Up @@ -324,6 +370,18 @@ def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
for eff in effs)
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

def marginal_effect_inference(self, T, X=None):
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
pred = cme_inf.point_estimate
pred_stderr = cme_inf.stderr
if X is None:
pred = np.repeat(pred, shape(T)[0], axis=0)
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
return InferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', pred_dist=None, fname_transformer=None)
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
Expand All @@ -346,6 +404,26 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_inference(self, X=None):
""" Inference results for the quantities :math:`\\theta(X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.
Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
Returns
-------
InferenceResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass


class TreatmentExpansionMixin(BaseCateEstimator):
"""Mixin which automatically handles promotions of scalar treatments to the appropriate shape."""
Expand Down Expand Up @@ -454,6 +532,18 @@ def coef__interval(self, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def coef__inference(self):
""" The inference of coefficients in the linear model of the constant marginal treatment
effect.
Returns
-------
InferenceResults: object
The inference of the coefficients in the final linear model
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__interval(self, *, alpha=0.1):
""" The intercept in the linear model of the constant marginal treatment
Expand All @@ -472,6 +562,43 @@ def intercept__interval(self, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__inference(self):
""" The inference of intercept in the linear model of the constant marginal treatment
effect.
Returns
-------
InferenceResults: object
The inference of the intercept in the final linear model
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect.
Parameters
----------
alpha: optional float in [0, 1] (default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.
value: optinal float (default=0)
The mean value of the metric you'd like to test under null hypothesis.
decimals: optinal int (default=3)
Number of decimal places to round each column to.
feat_name: optional list of strings or None (default is None)
The input of the feature names
Returns
-------
smry : Summary instance
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
"""
Expand Down Expand Up @@ -568,6 +695,23 @@ def coef__interval(self, T, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def coef__inference(self, T):
""" The inference for the coefficients in the linear model of the
constant marginal treatment effect associated with treatment T.
Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.
Returns
-------
InferenceResults: object
The inference of the coefficients in the final linear model
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__interval(self, T, *, alpha=0.1):
""" The intercept in the linear model of the constant marginal treatment
Expand All @@ -588,6 +732,49 @@ def intercept__interval(self, T, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__inference(self, T):
""" The inference of the intercept in the linear model of the constant marginal treatment
effect associated with treatment T.
Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.
Returns
-------
InferenceResults: object
The inference of the intercept in the final linear model
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect associated with treatment T.
Parameters
----------
alpha: optional float in [0, 1] (default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.
value: optinal float (default=0)
The mean value of the metric you'd like to test under null hypothesis.
decimals: optinal int (default=3)
Number of decimal places to round each column to.
feat_name: optional list of strings or None (default is None)
The input of the feature names
Returns
-------
smry : Summary instance
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass


class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscreteMixin):
"""
Expand Down
2 changes: 2 additions & 0 deletions econml/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ def __init__(self,
fit_cate_intercept=True,
min_propensity=1e-6,
n_splits=2, random_state=None):
self.fit_cate_intercept = fit_cate_intercept
super().__init__(model_propensity=model_propensity,
model_regression=model_regression,
model_final=StatsModelsLinearRegression(fit_intercept=fit_cate_intercept),
Expand Down Expand Up @@ -837,6 +838,7 @@ def __init__(self,
tol=1e-4,
min_propensity=1e-6,
n_splits=2, random_state=None):
self.fit_cate_intercept = fit_cate_intercept
model_final = DebiasedLasso(
alpha=alpha,
fit_intercept=fit_cate_intercept,
Expand Down
Loading

0 comments on commit f42251e

Please sign in to comment.