Enable explicit categories for OrthoLearner and Metalearner
kbattocchi committed May 8, 2020
1 parent f6f82d2 commit dad8dcd
Showing 8 changed files with 242 additions and 68 deletions.
29 changes: 16 additions & 13 deletions econml/_ortho_learner.py
@@ -27,7 +27,7 @@ class in this module implements the general logic in a very versatile way
 import numpy as np
 import copy
 from warnings import warn
-from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose,
+from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose, inverse_onehot,
                         broadcast_unit_treatments, reshape_treatmentwise_effects,
                         StatsModelsLinearRegression, LassoCVWrapper)
 from sklearn.model_selection import KFold, StratifiedKFold, check_cv
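The newly imported inverse_onehot utility recovers integer treatment labels from a drop-first one-hot matrix. A minimal sketch of its semantics (this standalone reimplementation, named inverse_onehot_sketch here, is illustrative; econml's actual helper may differ in details):

    import numpy as np

    def inverse_onehot_sketch(T):
        # Each row of T is a drop-first one-hot encoding: the control category
        # is the all-zeros row, and column j marks category j + 1.
        # Dotting with [1, ..., k-1] therefore recovers the integer label.
        return (T @ np.arange(1, T.shape[1] + 1)).astype(int)

    T = np.array([[0, 0],   # control         -> 0
                  [1, 0],   # second category -> 1
                  [0, 1]])  # third category  -> 2
    print(inverse_onehot_sketch(T))  # [0 1 2]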
@@ -277,6 +277,10 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
     discrete_instrument: bool
         Whether the instrument values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -342,7 +346,7 @@ def score(self, Y, T, W=None, nuisances=None):
     est = _OrthoLearner(ModelNuisance(LinearRegression(), LinearRegression()),
                         ModelFinal(),
                         n_splits=2, discrete_treatment=False, discrete_instrument=False,
-                        random_state=None)
+                        categories='auto', random_state=None)
     est.fit(y, X[:, 0], W=X[:, 1:])
     >>> est.score_
@@ -405,7 +409,7 @@ def score(self, Y, T, W=None, nuisances=None):
     y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
     est = _OrthoLearner(ModelNuisance(LogisticRegression(solver='lbfgs'), LinearRegression()),
                         ModelFinal(), n_splits=2, discrete_treatment=True, discrete_instrument=False,
-                        random_state=None)
+                        categories='auto', random_state=None)
     est.fit(y, T, W=W)
     >>> est.score_
@@ -435,7 +439,7 @@ def score(self, Y, T, W=None, nuisances=None):
     """

     def __init__(self, model_nuisance, model_final, *,
-                 discrete_treatment, discrete_instrument, n_splits, random_state):
+                 discrete_treatment, discrete_instrument, categories, n_splits, random_state):
         self._model_nuisance = clone(model_nuisance, safe=False)
         self._models_nuisance = None
         self._model_final = clone(model_final, safe=False)
@@ -444,8 +448,9 @@ def __init__(self, model_nuisance, model_final, *,
         self._discrete_instrument = discrete_instrument
         self._random_state = check_random_state(random_state)
         if discrete_treatment:
-            self._label_encoder = LabelEncoder()
-            self._one_hot_encoder = OneHotEncoder(categories='auto', sparse=False)
+            if categories != 'auto':
+                categories = [categories]  # OneHotEncoder expects a 2D array with features per column
+            self._one_hot_encoder = OneHotEncoder(categories=categories, sparse=False, drop='first')
         super().__init__()

     @staticmethod
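This constructor change is the core of the commit: the old LabelEncoder plus manual column dropping is replaced by a single OneHotEncoder whose explicit category order makes the first listed category the dropped (control) level. A self-contained sketch of that encoder behavior (the treatment values 'none', 'low', 'high' are made up for illustration):

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    # 'none' is listed first, so drop='first' removes its column and it
    # becomes the implicit control treatment (the all-zeros row).
    enc = OneHotEncoder(categories=[['none', 'low', 'high']], sparse=False, drop='first')
    T = np.array(['low', 'none', 'high']).reshape(-1, 1)
    print(enc.fit_transform(T))
    # [[1. 0.]
    #  [0. 0.]
    #  [0. 1.]]   columns are 'low' and 'high'; 'none' rows are all zeros

The outer list around the category values is needed because the encoder takes one category list per input column, which is what the `categories = [categories]` line above accounts for.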
@@ -535,14 +540,14 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
         stratify = self._discrete_treatment or self._discrete_instrument

         if self._discrete_treatment:
-            T = self._label_encoder.fit_transform(T.ravel())
+            T = self._one_hot_encoder.fit_transform(reshape(T, (-1, 1)))

         if self._discrete_instrument:
             z_enc = LabelEncoder()
             Z = z_enc.fit_transform(Z.ravel())

             if self._discrete_treatment:  # need to stratify on combination of Z and T
-                to_split = T + Z * len(self._label_encoder.classes_)
+                to_split = inverse_onehot(T) + Z * len(self._one_hot_encoder.categories_[0])
             else:
                 to_split = Z  # just stratify on Z

@@ -554,7 +559,8 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
                                  reshape(z_enc.transform(Z.ravel()), (-1, 1)))[:, 1:]),
                 validate=False)
         else:
-            to_split = T  # stratify on T if discrete, and fine to pass T as second arg to KFold.split even when not
+            # stratify on T if discrete, and fine to pass T as second arg to KFold.split even when not
+            to_split = inverse_onehot(T) if self._discrete_treatment else T
             self.z_transformer = None

         if self._n_splits == 1:  # special case, no cross validation
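Since T now stays one-hot encoded through splitting, the stratification labels are rebuilt with inverse_onehot and combined with the label-encoded instrument so that each (T, Z) pair lands in its own stratum. A toy illustration of the arithmetic (all values assumed):

    import numpy as np

    t_labels = np.array([0, 1, 2, 0, 1, 2])  # e.g. inverse_onehot(T)
    z_labels = np.array([0, 0, 0, 1, 1, 1])  # label-encoded instrument
    n_treatment_categories = 3               # len(encoder.categories_[0])

    # A unique integer per (T, Z) combination, suitable for StratifiedKFold
    to_split = t_labels + z_labels * n_treatment_categories
    print(to_split)  # [0 1 2 3 4 5]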
@@ -574,14 +580,11 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
             folds = splitter.split(np.ones((T.shape[0], 1)), to_split)

         if self._discrete_treatment:
-            # drop first column since all columns sum to one
-            T = self._one_hot_encoder.fit_transform(reshape(T, (-1, 1)))[:, 1:]
-
             self._d_t = shape(T)[1:]
             self.transformer = FunctionTransformer(
                 func=(lambda T:
                       self._one_hot_encoder.transform(
-                          reshape(self._label_encoder.transform(T.ravel()), (-1, 1)))[:, 1:]),
+                          reshape(T, (-1, 1)))),
                 validate=False)

         nuisances, fitted_models, fitted_inds, scores = _crossfit(self._model_nuisance, folds,
9 changes: 7 additions & 2 deletions econml/_rlearner.py
@@ -89,6 +89,10 @@ class _RLearner(_OrthoLearner):
     discrete_treatment: bool
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -146,7 +150,7 @@ def predict(self, X):
     est = _RLearner(ModelFirst(LinearRegression()),
                     ModelFirst(LinearRegression()),
                     ModelFinal(),
-                    n_splits=2, discrete_treatment=False, random_state=None)
+                    n_splits=2, discrete_treatment=False, categories='auto', random_state=None)
     est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
     >>> est.const_marginal_effect(np.ones((1,1)))
@@ -197,7 +201,7 @@ def predict(self, X):
     """

     def __init__(self, model_y, model_t, model_final,
-                 discrete_treatment, n_splits, random_state):
+                 discrete_treatment, categories, n_splits, random_state):
         class ModelNuisance:
             """
             Nuisance model fits the model_y and model_t at fit time and at predict time
@@ -276,6 +280,7 @@ def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None
                          ModelFinal(model_final),
                          discrete_treatment=discrete_treatment,
                          discrete_instrument=False,  # no instrument, so doesn't matter
+                         categories=categories,
                          n_splits=n_splits,
                          random_state=random_state)
40 changes: 38 additions & 2 deletions econml/dml.py
@@ -388,6 +388,10 @@ class takes as input the parameter `model_t`, which is an arbitrary scikit-learn
     discrete_treatment: bool, optional, default False
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list, default 'auto'
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable, optional, default 2
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -418,6 +422,7 @@ def __init__(self,
                  fit_cate_intercept=True,
                  linear_first_stages=False,
                  discrete_treatment=False,
+                 categories='auto',
                  n_splits=2,
                  random_state=None):

@@ -436,6 +441,7 @@ def __init__(self,
                                         featurizer, linear_first_stages, discrete_treatment),
                          model_final=_FinalWrapper(model_final, fit_cate_intercept, featurizer, False),
                          discrete_treatment=discrete_treatment,
+                         categories=categories,
                          n_splits=n_splits,
                          random_state=random_state)

@@ -472,6 +478,10 @@ class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
     discrete_treatment: bool, optional (default is ``False``)
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list, default 'auto'
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable, optional (Default=2)
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -502,6 +512,7 @@ def __init__(self,
                  fit_cate_intercept=True,
                  linear_first_stages=True,
                  discrete_treatment=False,
+                 categories='auto',
                  n_splits=2,
                  random_state=None):
         super().__init__(model_y=model_y,
@@ -511,6 +522,7 @@ def __init__(self,
                          fit_cate_intercept=fit_cate_intercept,
                          linear_first_stages=linear_first_stages,
                          discrete_treatment=discrete_treatment,
+                         categories=categories,
                          n_splits=n_splits,
                          random_state=random_state)

@@ -598,6 +610,10 @@ class SparseLinearDMLCateEstimator(DebiasedLassoCateEstimatorMixin, DMLCateEstimator):
     discrete_treatment: bool, optional (default is ``False``)
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list, default 'auto'
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable, optional (Default=2)
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -630,6 +646,7 @@ def __init__(self,
                  fit_cate_intercept=True,
                  linear_first_stages=True,
                  discrete_treatment=False,
+                 categories='auto',
                  n_splits=2,
                  random_state=None):
         model_final = MultiOutputDebiasedLasso(
@@ -644,6 +661,7 @@ def __init__(self,
                          fit_cate_intercept=fit_cate_intercept,
                          linear_first_stages=linear_first_stages,
                          discrete_treatment=discrete_treatment,
+                         categories=categories,
                          n_splits=n_splits,
                          random_state=random_state)

@@ -717,6 +735,10 @@ class KernelDMLCateEstimator(DMLCateEstimator):
     discrete_treatment: bool, optional (default is ``False``)
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list, default 'auto'
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable, optional (Default=2)
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -741,7 +763,7 @@ class KernelDMLCateEstimator(DMLCateEstimator):
     """

     def __init__(self, model_y=WeightedLassoCVWrapper(), model_t='auto', fit_cate_intercept=True,
-                 dim=20, bw=1.0, discrete_treatment=False, n_splits=2, random_state=None):
+                 dim=20, bw=1.0, discrete_treatment=False, categories='auto', n_splits=2, random_state=None):
         class RandomFeatures(TransformerMixin):
             def __init__(self, random_state):
                 self._random_state = check_random_state(random_state)
@@ -758,7 +780,9 @@ def transform(self, X):
                          model_final=ElasticNetCV(fit_intercept=False),
                          featurizer=RandomFeatures(random_state),
                          fit_cate_intercept=fit_cate_intercept,
-                         discrete_treatment=discrete_treatment, n_splits=n_splits, random_state=random_state)
+                         discrete_treatment=discrete_treatment,
+                         categories=categories,
+                         n_splits=n_splits, random_state=random_state)


 class NonParamDMLCateEstimator(_BaseDMLCateEstimator):
@@ -790,6 +814,10 @@ class NonParamDMLCateEstimator(_BaseDMLCateEstimator):
     discrete_treatment: bool, optional (default is ``False``)
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list, default 'auto'
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_splits: int, cross-validation generator or an iterable, optional (Default=2)
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -818,6 +846,7 @@ def __init__(self,
                  model_y, model_t, model_final,
                  featurizer=None,
                  discrete_treatment=False,
+                 categories='auto',
                  n_splits=2,
                  random_state=None):

@@ -830,6 +859,7 @@ def __init__(self,
                                         featurizer, False, discrete_treatment),
                          model_final=_FinalWrapper(model_final, False, featurizer, True),
                          discrete_treatment=discrete_treatment,
+                         categories=categories,
                          n_splits=n_splits,
                          random_state=random_state)

@@ -852,6 +882,10 @@ class ForestDMLCateEstimator(NonParamDMLCateEstimator):
     discrete_treatment: bool, optional (default is ``False``)
         Whether the treatment values should be treated as categorical, rather than continuous, quantities

+    categories: 'auto' or list, default 'auto'
+        The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
+        The first category will be treated as the control treatment.
+
     n_crossfit_splits: int, cross-validation generator or an iterable, optional (Default=2)
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -986,6 +1020,7 @@ class ForestDMLCateEstimator(NonParamDMLCateEstimator):
     def __init__(self,
                  model_y, model_t,
                  discrete_treatment=False,
+                 categories='auto',
                  n_crossfit_splits=2,
                  n_estimators=100,
                  criterion="mse",
@@ -1018,6 +1053,7 @@ def __init__(self,
         super().__init__(model_y=model_y, model_t=model_t,
                          model_final=model_final, featurizer=None,
                          discrete_treatment=discrete_treatment,
+                         categories=categories,
                          n_splits=n_crossfit_splits, random_state=random_state)

     def _get_inference_options(self):
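From the user's side, the new parameter lets callers pin the control category explicitly instead of relying on the sorted order of unique treatment values. A hedged usage sketch against the public DML API of this era (the data and model choices are invented, and exact signatures may differ slightly across econml versions):

    import numpy as np
    from sklearn.linear_model import LassoCV, LogisticRegressionCV
    from econml.dml import LinearDMLCateEstimator

    # Three-valued string treatment; listing 'none' first makes it the control.
    X = np.random.normal(size=(500, 3))
    T = np.random.choice(['none', 'low', 'high'], size=500)
    y = (T == 'high') * X[:, 0] + np.random.normal(size=500)

    est = LinearDMLCateEstimator(model_y=LassoCV(),
                                 model_t=LogisticRegressionCV(),
                                 discrete_treatment=True,
                                 categories=['none', 'low', 'high'])
    est.fit(y, T, X=X)
    # Effect of moving from the control 'none' to 'high' at the first few rows
    print(est.effect(X[:5], T0='none', T1='high'))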