CATE validation - uplift uniform confidence bands #840

Merged · 13 commits · Mar 19, 2024
14 changes: 14 additions & 0 deletions doc/reference.rst
@@ -147,6 +147,20 @@ CATE Interpreters
     econml.cate_interpreter.SingleTreeCateInterpreter
     econml.cate_interpreter.SingleTreePolicyInterpreter

+.. _validation_api:
+
+CATE Validation
+---------------
+
+.. autosummary::
+    :toctree: _autosummary
+
+    econml.validate.DRTester
+    econml.validate.BLPEvaluationResults
+    econml.validate.CalibrationEvaluationResults
+    econml.validate.UpliftEvaluationResults
+    econml.validate.EvaluationResults
+
 .. _scorers_api:

 CATE Scorers
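For orientation, the workflow these new API entries document looks roughly like the following. This is a minimal sketch assembled from the calls exercised in the test diff below; the synthetic data, the nuisance model choices, and the DML configuration are illustrative assumptions, not code from this PR.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression

from econml.dml import DML
from econml.validate import DRTester

# Illustrative synthetic data: binary treatment, effect heterogeneity in X[:, 0]
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 5))
D = rng.binomial(1, 0.5, size=n)
Y = D * X[:, 0] + rng.normal(size=n)
Xtrain, Xval = X[:n // 2], X[n // 2:]
Dtrain, Dval = D[:n // 2], D[n // 2:]
Ytrain, Yval = Y[:n // 2], Y[n // 2:]

# A CATE model to validate; the tests use econml.dml.DML, but this
# particular configuration is a guess
cate = DML(
    model_y=GradientBoostingRegressor(),
    model_t=RandomForestClassifier(),
    model_final=LinearRegression(fit_intercept=False),
    discrete_treatment=True,
).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

# Nuisance models for the doubly robust pseudo-outcomes, then the three tests
tester = DRTester(
    model_regression=GradientBoostingRegressor(),  # outcome model
    model_propensity=RandomForestClassifier(),     # propensity model
    cate=cate,
).fit_nuisance(Xval, Dval, Yval, Xtrain, Dtrain, Ytrain)

blp_res = tester.evaluate_blp(Xval)                             # best linear predictor test
cal_res = tester.evaluate_cal(Xval, Xtrain)                     # calibration test
qini_res = tester.evaluate_uplift(Xval, Xtrain)                 # uplift test, QINI metric
autoc_res = tester.evaluate_uplift(Xval, Xtrain, metric='toc')  # uplift test, TOC metric
```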
48 changes: 27 additions & 21 deletions econml/tests/test_drtester.py
@@ -5,7 +5,7 @@
 import scipy.stats as st
 from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor

-from econml.validate.drtester import DRtester
+from econml.validate.drtester import DRTester
 from econml.dml import DML

@@ -70,7 +70,7 @@ def test_multi(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -123,7 +123,7 @@ def test_binary(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -148,8 +148,8 @@ def test_binary(self):
                 self.assertRaises(ValueError, res.plot_toc, k)
             else:  # real treatment, k = 1
                 self.assertTrue(res.plot_cal(k) is not None)
-                self.assertTrue(res.plot_qini(k) is not None)
-                self.assertTrue(res.plot_toc(k) is not None)
+                self.assertTrue(res.plot_qini(k, 'ucb2') is not None)
+                self.assertTrue(res.plot_toc(k, 'ucb1') is not None)

         self.assertLess(res_df.blp_pval.values[0], 0.05)  # heterogeneity
         self.assertGreater(res_df.cal_r_squared.values[0], 0)  # good R2
@@ -171,7 +171,7 @@ def test_nuisance_val_fit(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -193,8 +193,8 @@ def test_nuisance_val_fit(self):
         for kwargs in [{}, {'Xval': Xval}]:
             with self.assertRaises(Exception) as exc:
                 my_dr_tester.evaluate_cal(**kwargs)
-            self.assertTrue(
-                str(exc.exception) == "Must fit nuisance models on training sample data to use calibration test"
+            self.assertEqual(
+                str(exc.exception), "Must fit nuisance models on training sample data to use calibration test"
             )

     def test_exceptions(self):
@@ -212,7 +212,7 @@ def test_exceptions(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -223,11 +223,11 @@ def test_exceptions(self):
             with self.assertRaises(Exception) as exc:
                 func()
             if func.__name__ == 'evaluate_cal':
-                self.assertTrue(
-                    str(exc.exception) == "Must fit nuisance models on training sample data to use calibration test"
+                self.assertEqual(
+                    str(exc.exception), "Must fit nuisance models on training sample data to use calibration test"
                 )
             else:
-                self.assertTrue(str(exc.exception) == "Must fit nuisances before evaluating")
+                self.assertEqual(str(exc.exception), "Must fit nuisances before evaluating")

         my_dr_tester = my_dr_tester.fit_nuisance(
             Xval, Dval, Yval, Xtrain, Dtrain, Ytrain
@@ -242,12 +242,12 @@ def test_exceptions(self):
             with self.assertRaises(Exception) as exc:
                 func()
             if func.__name__ == 'evaluate_blp':
-                self.assertTrue(
-                    str(exc.exception) == "CATE predictions not yet calculated - must provide Xval"
+                self.assertEqual(
+                    str(exc.exception), "CATE predictions not yet calculated - must provide Xval"
                 )
             else:
-                self.assertTrue(str(exc.exception) ==
-                                "CATE predictions not yet calculated - must provide both Xval, Xtrain")
+                self.assertEqual(str(exc.exception),
+                                 "CATE predictions not yet calculated - must provide both Xval, Xtrain")

         for func in [
             my_dr_tester.evaluate_cal,
@@ -256,19 +256,19 @@ def test_exceptions(self):
         ]:
             with self.assertRaises(Exception) as exc:
                 func(Xval=Xval)
-            self.assertTrue(
-                str(exc.exception) == "CATE predictions not yet calculated - must provide both Xval, Xtrain")
+            self.assertEqual(
+                str(exc.exception), "CATE predictions not yet calculated - must provide both Xval, Xtrain")

         cal_res = my_dr_tester.evaluate_cal(Xval, Xtrain)
         self.assertGreater(cal_res.cal_r_squared[0], 0)  # good R2

         with self.assertRaises(Exception) as exc:
             my_dr_tester.evaluate_uplift(metric='blah')
-        self.assertTrue(
-            str(exc.exception) == "Unsupported metric - must be one of ['toc', 'qini']"
+        self.assertEqual(
+            str(exc.exception), "Unsupported metric 'blah' - must be one of ['toc', 'qini']"
         )

-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -278,5 +278,11 @@ def test_exceptions(self):
         qini_res = my_dr_tester.evaluate_uplift(Xval, Xtrain)
         self.assertLess(qini_res.pvals[0], 0.05)

+        with self.assertRaises(Exception) as exc:
+            qini_res.plot_uplift(tmt=1, err_type='blah')
+        self.assertEqual(
+            str(exc.exception), "Invalid error type 'blah'; must be one of [None, 'ucb2', 'ucb1']"
+        )
+
         autoc_res = my_dr_tester.evaluate_uplift(Xval, Xtrain, metric='toc')
         self.assertLess(autoc_res.pvals[0], 0.05)
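The uniform confidence bands this PR adds are reached through the plotting helpers the tests above exercise. A hedged sketch, continuing from the example given after the reference.rst hunk (so `qini_res` is the object returned there by `evaluate_uplift`); the interpretation in the comments is inferred from the tests, not stated in the diff:

```python
# Uplift curve for treatment arm 1 with a uniform confidence band.
# Per the message asserted in test_exceptions, err_type must be one of
# [None, 'ucb2', 'ucb1']; passing None presumably draws the curve alone.
fig = qini_res.plot_uplift(tmt=1, err_type='ucb2')

# test_binary passes the same style of argument to the per-metric helpers
# on its combined results object (called `res` in that test):
#     res.plot_qini(k, 'ucb2')
#     res.plot_toc(k, 'ucb1')
```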
6 changes: 4 additions & 2 deletions econml/validate/__init__.py
@@ -5,7 +5,9 @@
 A suite of validation methods for CATE models.
 """

-from .drtester import DRtester
+from .drtester import DRTester
+from .results import BLPEvaluationResults, CalibrationEvaluationResults, UpliftEvaluationResults, EvaluationResults


-__all__ = ['DRtester']
+__all__ = ['DRTester',
+           'BLPEvaluationResults', 'CalibrationEvaluationResults', 'UpliftEvaluationResults', 'EvaluationResults']
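With this change, the validation entry points are importable directly from the subpackage, matching the names now listed in `__all__`:

```python
# Public API surface of econml.validate after this PR
from econml.validate import (
    BLPEvaluationResults,
    CalibrationEvaluationResults,
    DRTester,
    EvaluationResults,
    UpliftEvaluationResults,
)
```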