add arg to allow missing values in W and sometimes X (#791)
* enable NaNs in W

* linting

* add tests for estimators that handle missing values in W

* allow missing values in X for some _OrthoLearner subclasses

* refactor keyword arg to be bool only, add more tests

* linting

* enable missing values for metalearners and ORF; fix DoWhy-wrapped discrete-treatment DMLOrthoForest

* update arg name to allow_missing, add docstrings

* add warning when missing values detected

* dummy commit

* dummy commit revert

---------

Signed-off-by: Fabio Vera <fabiovera@microsoft.com>
fverac authored Sep 29, 2023
1 parent de4f05d commit 25c3b3b
Showing 13 changed files with 668 additions and 90 deletions.
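
In effect, an estimator constructed with allow_missing=True now accepts NaN in W (and, for a few subclasses, in X), provided the user supplies nuisance models that tolerate NaN. A minimal usage sketch against this commit's API; scikit-learn's HistGradientBoostingRegressor is one NaN-tolerant choice, and the data below is synthetic:

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingRegressor
    from econml.dml import LinearDML

    rng = np.random.default_rng(0)
    n = 1000
    X = rng.normal(size=(n, 2))              # LinearDML still requires X fully observed
    W = rng.normal(size=(n, 3))
    W[rng.random(W.shape) < 0.1] = np.nan    # ~10% of the controls go missing
    T = X[:, 0] + np.nansum(W, axis=1) + rng.normal(size=n)
    Y = 2 * T + X[:, 1] + rng.normal(size=n)

    est = LinearDML(model_y=HistGradientBoostingRegressor(),  # nuisances must accept NaN
                    model_t=HistGradientBoostingRegressor(),
                    allow_missing=True)                       # otherwise fit() raises on NaN
    est.fit(Y, T, X=X, W=W)
    print(est.effect(X[:5]))                 # true effect is a constant 2.0
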
22 changes: 18 additions & 4 deletions econml/_ortho_learner.py
@@ -298,6 +298,10 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
How to aggregate the nuisance value for each sample across the `mc_iters` monte carlo iterations of
cross-fitting.
allow_missing: bool
Whether to allow missing values in X, W. If True, will need to supply nuisance models that can handle
missing values.
Examples
--------
@@ -434,7 +438,7 @@ def _gen_ortho_learner_model_final(self):
def __init__(self, *,
discrete_treatment, treatment_featurizer,
discrete_instrument, categories, cv, random_state,
mc_iters=None, mc_agg='mean'):
mc_iters=None, mc_agg='mean', allow_missing=False):
self.cv = cv
self.discrete_treatment = discrete_treatment
self.treatment_featurizer = treatment_featurizer
@@ -443,8 +447,12 @@ def __init__(self, *,
self.categories = categories
self.mc_iters = mc_iters
self.mc_agg = mc_agg
self.allow_missing = allow_missing
super().__init__()

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.allow_missing else []

@abstractmethod
def _gen_ortho_learner_model_nuisance(self):
""" Must return a fresh instance of a nuisance model
@@ -605,8 +613,12 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N
assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
"is not supported when treatment is discrete"
if check_input:
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
Y, T, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
Y, T, Z, sample_weight, freq_weight, sample_var, groups)
X, = check_input_arrays(
X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
W, = check_input_arrays(
W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

if not only_final:
@@ -878,7 +890,9 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
"""
if not hasattr(self._ortho_learner_model_final, 'score'):
raise AttributeError("Final model does not have a score method!")
Y, T, X, W, Z = check_input_arrays(Y, T, X, W, Z)
Y, T, Z = check_input_arrays(Y, T, Z)
X, = check_input_arrays(X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
self._check_fitted_dims(X)
self._check_fitted_dims_w_z(W, Z)
X, T = self._expand_treatments(X, T)
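
The split above routes X and W through a NaN-aware validation path while Y, T, and Z keep strict checking. check_input_arrays forwards force_all_finite to scikit-learn's check_array, whose 'allow-nan' mode permits NaN but still rejects infinities. A standalone sketch of the same gating (the helper name is illustrative, not the library's):

    import numpy as np
    from sklearn.utils import check_array

    def validate(arr, name, allowed_missing_vars):
        # NaN is tolerated only for whitelisted variables; inf is always rejected.
        finite = 'allow-nan' if name in allowed_missing_vars else True
        return check_array(arr, force_all_finite=finite)

    W = np.array([[1.0, np.nan], [0.5, 2.0]])
    validate(W, 'W', allowed_missing_vars=['X', 'W'])    # passes
    # validate(W, 'Z', allowed_missing_vars=['X', 'W'])  # would raise ValueError
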
9 changes: 7 additions & 2 deletions econml/dml/_rlearner.py
@@ -184,6 +184,10 @@ class _RLearner(_OrthoLearner):
How to aggregate the nuisance value for each sample across the `mc_iters` monte carlo iterations of
cross-fitting.
allow_missing: bool
Whether to allow missing values in X, W. If True, will need to supply nuisance models that can handle
missing values.
Examples
--------
@@ -272,15 +276,16 @@ def _gen_rlearner_model_final(self):
"""

def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
cv, random_state, mc_iters=None, mc_agg='mean'):
cv, random_state, mc_iters=None, mc_agg='mean', allow_missing=False):
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
cv=cv,
random_state=random_state,
mc_iters=mc_iters,
mc_agg=mc_agg)
mc_agg=mc_agg,
allow_missing=allow_missing)

@abstractmethod
def _gen_model_y(self):
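
_RLearner simply threads the flag through to _OrthoLearner; the pattern that repeats across the PR is the _gen_allowed_missing_vars hook, which defaults to ['X', 'W'] in the base class and is overridden by subclasses whose final stage regresses on X. Schematically (class names here are illustrative, not the library's):

    class _OrthoLearnerLike:
        def __init__(self, allow_missing=False):
            self.allow_missing = allow_missing

        def _gen_allowed_missing_vars(self):
            # base default: NaN may appear in X and W when opted in
            return ['X', 'W'] if self.allow_missing else []

    class _LinearFinalStageLike(_OrthoLearnerLike):
        def _gen_allowed_missing_vars(self):
            # a linear final stage regresses on X, so X must stay fully observed
            return ['W'] if self.allow_missing else []
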
16 changes: 13 additions & 3 deletions econml/dml/causal_forest.py
@@ -514,6 +514,10 @@ class CausalForestDML(_BaseDML):
verbose : int, default 0
Controls the verbosity when fitting and predicting.
allow_missing: bool
Whether to allow missing values in W. If True, will need to supply model_y, model_t that can handle
missing values.
Examples
--------
A simple example with the default models and discrete treatment:
@@ -601,7 +605,8 @@ def __init__(self, *,
subforest_size=4,
n_jobs=-1,
random_state=None,
verbose=0):
verbose=0,
allow_missing=False):

# TODO: consider whether we need more care around stateful featurizers,
# since we clone it and fit separate copies
@@ -636,7 +641,11 @@
cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
random_state=random_state,
allow_missing=allow_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.allow_missing else []

def _get_inference_options(self):
options = super()._get_inference_options()
@@ -737,7 +746,8 @@ def tune(self, Y, T, *, X=None, W=None,
all parameters of the object have been set to the best performing parameters from the tuning grid.
"""
from ..score import RScorer # import here to avoid circular import issue
Y, T, X, W, sample_weight, groups = check_input_arrays(Y, T, X, W, sample_weight, groups)
Y, T, X, sample_weight, groups = check_input_arrays(Y, T, X, sample_weight, groups)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)

if params == 'auto':
params = {
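
Since tune() now validates W under the same whitelist, hyperparameter search also works with partially observed controls. A hedged sketch on synthetic data (the forest splits on X, so X must remain fully observed):

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingRegressor
    from econml.dml import CausalForestDML

    rng = np.random.default_rng(1)
    n = 500
    X = rng.normal(size=(n, 2))
    W = rng.normal(size=(n, 3))
    W[rng.random(W.shape) < 0.1] = np.nan
    T = rng.normal(size=n)
    Y = (1 + X[:, 0]) * T + rng.normal(size=n)

    est = CausalForestDML(model_y=HistGradientBoostingRegressor(),
                          model_t=HistGradientBoostingRegressor(),
                          allow_missing=True)
    est.tune(Y, T, X=X, W=W, params='auto')  # no longer rejects the NaNs in W
    est.fit(Y, T, X=X, W=W)
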
68 changes: 57 additions & 11 deletions econml/dml/dml.py
@@ -108,8 +108,9 @@ def __init__(self, model_final, fit_cate_intercept, featurizer, use_weight_trick
else:
self._fit_cate_intercept = fit_cate_intercept
if self._fit_cate_intercept:
# data is already validated at initial fit time
add_intercept_trans = FunctionTransformer(add_intercept,
validate=True)
validate=False)
if featurizer:
self._featurizer = Pipeline([('featurize', self._original_featurizer),
('add_intercept', add_intercept_trans)])
@@ -410,6 +411,10 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.
allow_missing: bool
Whether to allow missing values in X, W. If True, will need to supply model_y, model_t, and model_final
that can handle missing values.
Examples
--------
A simple example with discrete treatment and a linear model_final (equivalent to LinearDML):
@@ -466,7 +471,8 @@ def __init__(self, *,
cv=2,
mc_iters=None,
mc_agg='mean',
random_state=None):
random_state=None,
allow_missing=False):
# TODO: consider whether we need more care around stateful featurizers,
# since we clone it and fit separate copies
self.fit_cate_intercept = fit_cate_intercept
@@ -481,7 +487,11 @@ def __init__(self, *,
cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
random_state=random_state,
allow_missing=allow_missing)

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.allow_missing else []

def _gen_featurizer(self):
return clone(self.featurizer, safe=False)
@@ -642,6 +652,10 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.
allow_missing: bool
Whether to allow missing values in W. If True, will need to supply model_y, model_t that can handle
missing values.
Examples
--------
A simple example with the default models and discrete treatment:
@@ -692,7 +706,8 @@ def __init__(self, *,
cv=2,
mc_iters=None,
mc_agg='mean',
random_state=None):
random_state=None,
allow_missing=False):
super().__init__(model_y=model_y,
model_t=model_t,
model_final=None,
@@ -705,7 +720,11 @@ def __init__(self, *,
cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,)
random_state=random_state,
allow_missing=allow_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.allow_missing else []

def _gen_model_final(self):
return StatsModelsLinearRegression(fit_intercept=False)
@@ -876,6 +895,10 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.
allow_missing: bool
Whether to allow missing values in W. If True, will need to supply model_y, model_t that can handle
missing values.
Examples
--------
A simple example with the default models and discrete treatment:
@@ -932,7 +955,8 @@ def __init__(self, *,
cv=2,
mc_iters=None,
mc_agg='mean',
random_state=None):
random_state=None,
allow_missing=False):
self.alpha = alpha
self.n_alphas = n_alphas
self.alpha_cov = alpha_cov
@@ -952,7 +976,11 @@ def __init__(self, *,
cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
random_state=random_state,
allow_missing=allow_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.allow_missing else []

def _gen_model_final(self):
return MultiOutputDebiasedLasso(alpha=self.alpha,
@@ -1104,6 +1132,10 @@ class KernelDML(DML):
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.
allow_missing: bool
Whether to allow missing values in W. If True, will need to supply model_y, model_t that can handle
missing values.
Examples
--------
A simple example with the default models and discrete treatment:
@@ -1139,7 +1171,8 @@ def __init__(self, model_y='auto', model_t='auto',
bw=1.0,
cv=2,
mc_iters=None, mc_agg='mean',
random_state=None):
random_state=None,
allow_missing=False):
self.dim = dim
self.bw = bw
super().__init__(model_y=model_y,
@@ -1153,7 +1186,11 @@ def __init__(self, model_y='auto', model_t='auto',
cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
random_state=random_state,
allow_missing=allow_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.allow_missing else []

def _gen_model_final(self):
return ElasticNetCV(fit_intercept=False, random_state=self.random_state)
@@ -1285,6 +1322,10 @@ class NonParamDML(_BaseDML):
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.
allow_missing: bool
Whether to allow missing values in W. If True, will need to supply model_y, model_t, and model_final
that can handle missing values.
Examples
--------
A simple example with a discrete treatment:
@@ -1326,7 +1367,8 @@ def __init__(self, *,
cv=2,
mc_iters=None,
mc_agg='mean',
random_state=None):
random_state=None,
allow_missing=False):

# TODO: consider whether we need more care around stateful featurizers,
# since we clone it and fit separate copies
@@ -1340,7 +1382,11 @@ def __init__(self, *,
cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
random_state=random_state,
allow_missing=allow_missing)

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.allow_missing else []

def _get_inference_options(self):
# add blb to parent's options
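
Within this file the whitelist varies: the generic DML and NonParamDML keep X (their final models are user-supplied and may themselves tolerate NaN, which is also why the intercept transformer above drops validate=True), while the linear and kernel variants restrict to W. A sketch of the NonParamDML case, assuming NaN-tolerant gradient boosting throughout (synthetic data):

    import numpy as np
    from sklearn.ensemble import (HistGradientBoostingClassifier,
                                  HistGradientBoostingRegressor)
    from econml.dml import NonParamDML

    rng = np.random.default_rng(2)
    n = 800
    X = rng.normal(size=(n, 2))
    X[rng.random(X.shape) < 0.05] = np.nan   # NaN in X is allowed for NonParamDML
    T = rng.integers(0, 2, size=n)
    Y = 2.0 * T + np.nan_to_num(X[:, 1]) + rng.normal(size=n)

    est = NonParamDML(model_y=HistGradientBoostingRegressor(),
                      model_t=HistGradientBoostingClassifier(),
                      model_final=HistGradientBoostingRegressor(),  # final stage sees NaN X too
                      discrete_treatment=True,
                      allow_missing=True)
    est.fit(Y, T, X=X)
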
8 changes: 7 additions & 1 deletion econml/dowhy.py
@@ -142,7 +142,13 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, outcome_names=None, treatment_nam
column_names = outcome_names + treatment_names + feature_names + confounder_names + instrument_names

# transfer input to numpy arrays
Y, T, X, W, Z = check_input_arrays(Y, T, X, W, Z)
if 'X' in self._cate_estimator._gen_allowed_missing_vars():
raise ValueError(
'DoWhyWrapper does not support missing values in X. Please set allow_missing=False before proceeding.'
)
Y, T, X, Z = check_input_arrays(Y, T, X, Z)
W, = check_input_arrays(
W, force_all_finite='allow-nan' if 'W' in self._cate_estimator._gen_allowed_missing_vars() else True)
# transfer input to 2d arrays
n_obs = Y.shape[0]
Y, T, X, W, Z = reshape_arrays_2dim(n_obs, Y, T, X, W, Z)
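
So the wrapper refuses outright any estimator whose whitelist includes X, while W keeps the NaN-aware validation. A quick check of which estimators the guard admits (illustrative):

    from sklearn.ensemble import HistGradientBoostingRegressor as HGB
    from econml.dml import LinearDML, NonParamDML

    lin = LinearDML(model_y=HGB(), model_t=HGB(), allow_missing=True)
    nonparam = NonParamDML(model_y=HGB(), model_t=HGB(), model_final=HGB(),
                           allow_missing=True)
    print(lin._gen_allowed_missing_vars())       # ['W'] -> DoWhyWrapper.fit proceeds
    print(nonparam._gen_allowed_missing_vars())  # ['X', 'W'] -> DoWhyWrapper.fit raises
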