From 22eced600665184242f79a8b77aa0d6778e1848c Mon Sep 17 00:00:00 2001 From: Miruna Oprescu Date: Fri, 6 Mar 2020 13:36:22 -0500 Subject: [PATCH 1/2] Fixed SparseDML and bootstrap bug. --- econml/dml.py | 7 +++++-- econml/tests/test_dml.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/econml/dml.py b/econml/dml.py index e5153e6fa..05672355c 100644 --- a/econml/dml.py +++ b/econml/dml.py @@ -633,7 +633,7 @@ def __init__(self, n_splits=n_splits, random_state=random_state) - def fit(self, Y, T, X=None, W=None, sample_weight=None, inference=None): + def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None): """ Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·). @@ -649,6 +649,9 @@ def fit(self, Y, T, X=None, W=None, sample_weight=None, inference=None): Controls for each sample sample_weight: optional (n,) vector Weights for each row + sample_var: optional (n, n_y) vector + Variance of sample, in case it corresponds to summary of many samples. Currently + not in use by this method but will be supported in a future release. inference: string, `Inference` instance, or None Method for performing inference. This estimator supports 'bootstrap' (or an instance of :class:`.BootstrapInference`) and 'debiasedlasso' @@ -659,7 +662,7 @@ def fit(self, Y, T, X=None, W=None, sample_weight=None, inference=None): self """ # TODO: support sample_var - if sample_weight is not None and inference is not None: + if sample_var is not None and inference is not None: warn("This estimator does not yet support sample variances and inference does not take " "sample variances into account. This feature will be supported in a future release.") check_high_dimensional(X, T, threshold=5, featurizer=self.featurizer, diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 1420d5f84..72e928d61 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -117,7 +117,8 @@ def make_random(is_discrete, d): fit_cate_intercept=fit_cate_intercept, discrete_treatment=is_discrete), True, - [None, 'debiasedlasso']), + [None, 'debiasedlasso'] + + [BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else []), (KernelDMLCateEstimator(model_y=WeightedLasso(), model_t=model_t, fit_cate_intercept=fit_cate_intercept, From 7ff334e305f616a2e8f23a92db3962dc3c5bf173 Mon Sep 17 00:00:00 2001 From: Miruna Oprescu Date: Fri, 6 Mar 2020 14:20:32 -0500 Subject: [PATCH 2/2] Updated dml tests. --- econml/tests/test_dml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 72e928d61..4bc9ad53b 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -118,7 +118,7 @@ def make_random(is_discrete, d): discrete_treatment=is_discrete), True, [None, 'debiasedlasso'] + - [BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else []), + ([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])), (KernelDMLCateEstimator(model_y=WeightedLasso(), model_t=model_t, fit_cate_intercept=fit_cate_intercept,