Skip to content

Commit

Permalink
Correct default value of alphas (#476)
Browse files Browse the repository at this point in the history
The PopulationSummaryResults class should to default to self.alpha when the alpha parameter is not passed. This is not the case currently since some methods use the statement 'alpha = self.alpha if alpha is None else alpha', however alpha's default value in the methods is .1, thus self.alpha will be ignored even if the client does not pass the alpha parameter.

Without the fix ate_interval(), marginal_ate_interval() and const_marginal_ate_interval() all ignore the alpha parameter.
  • Loading branch information
mtanghu authored Aug 9, 2021
1 parent 1bb4f3f commit 2c140f2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
21 changes: 11 additions & 10 deletions econml/inference/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,8 +1179,8 @@ class PopulationSummaryResults:
"""

def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha, value, decimals, tol,
output_names=None, treatment_names=None):
def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha=0.1,
value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
self.pred = pred
self.pred_stderr = pred_stderr
self.mean_pred_stderr = mean_pred_stderr
Expand Down Expand Up @@ -1237,13 +1237,13 @@ def stderr_mean(self):
raise AttributeError("Only point estimates are available!")
return np.sqrt(np.mean(self.pred_stderr**2, axis=0))

def zstat(self, *, value=0):
def zstat(self, *, value=None):
"""
Get the z statistic of the mean point estimate of each treatment on each outcome for sample X.
Parameters
----------
value: optinal float (default=0)
value: optional float (default=0)
The mean value of the metric you'd like to test under null hypothesis.
Returns
Expand All @@ -1258,13 +1258,13 @@ def zstat(self, *, value=0):
zstat = (self.mean_point - value) / self.stderr_mean
return zstat

def pvalue(self, *, value=0):
def pvalue(self, *, value=None):
"""
Get the p value of the z test of each treatment on each outcome for sample X.
Parameters
----------
value: optinal float (default=0)
value: optional float (default=0)
The mean value of the metric you'd like to test under null hypothesis.
Returns
Expand All @@ -1275,10 +1275,11 @@ def pvalue(self, *, value=0):
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
value = self.value if value is None else value
pvalue = norm.sf(np.abs(self.zstat(value=value)), loc=0, scale=1) * 2
return pvalue

def conf_int_mean(self, *, alpha=.1):
def conf_int_mean(self, *, alpha=None):
"""
Get the confidence interval of the mean point estimate of each treatment on each outcome for sample X.
Expand Down Expand Up @@ -1323,7 +1324,7 @@ def std_point(self):
"""
return np.std(self.pred, axis=0)

def percentile_point(self, *, alpha=.1):
def percentile_point(self, *, alpha=None):
"""
Get the confidence interval of the point estimate of each treatment on each outcome for sample X.
Expand All @@ -1346,7 +1347,7 @@ def percentile_point(self, *, alpha=.1):
upper_percentile_point = np.percentile(self.pred, (1 - alpha / 2) * 100, axis=0)
return lower_percentile_point, upper_percentile_point

def conf_int_point(self, *, alpha=.1, tol=.001):
def conf_int_point(self, *, alpha=None, tol=None):
"""
Get the confidence interval of the point estimate of each treatment on each outcome for sample X.
Expand Down Expand Up @@ -1389,7 +1390,7 @@ def stderr_point(self):
"""
return np.sqrt(self.stderr_mean**2 + self.std_point**2)

def summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
def summary(self, alpha=None, value=None, decimals=None, tol=None, output_names=None, treatment_names=None):
"""
Output the summary inferences above.
Expand Down
12 changes: 12 additions & 0 deletions econml/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,18 @@ def test_can_summarize(self):
inference=BootstrapInference(5)
).summary(1)

def test_alpha(self):
Y, T, X, W = TestInference.Y, TestInference.T, TestInference.X, TestInference.W
est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression())
est.fit(Y, T, X=X, W=W)

# ensure alpha is passed
lb, ub = est.const_marginal_ate_interval(X, alpha=1)
assert (lb == ub).all()

lb, ub = est.const_marginal_ate_interval(X)
assert (lb != ub).all()

def test_inference_with_none_stderr(self):
Y, T, X, W = TestInference.Y, TestInference.T, TestInference.X, TestInference.W
est = DML(model_y=LinearRegression(),
Expand Down

0 comments on commit 2c140f2

Please sign in to comment.