
Vasilis/ortholearner refactor #132

Merged (65 commits) on Nov 5, 2019
Changes from 54 commits
Commits
a00da02
created ortho learner meta class. Put _Rlearner in separate file. Res…
vasilismsr Nov 2, 2019
d9449e0
notebook for ortholearner testing
vasilismsr Nov 2, 2019
8c5194d
check for non keyword argument being None in crossfit
vasilismsr Nov 2, 2019
4a87052
started tests for ortho learner
vasilismsr Nov 2, 2019
e442968
linting style for ortho learner.py
vasilismsr Nov 2, 2019
3adb748
linting style for _rlearner.py
vasilismsr Nov 2, 2019
51728b9
linting style for test_ortho_learner.py
vasilismsr Nov 2, 2019
2776d9e
linting
vasilismsr Nov 2, 2019
e83ded9
linting
vasilismsr Nov 2, 2019
bec8128
linting
vasilismsr Nov 2, 2019
c11d99d
fixed issues related to input output shapes. One mistake was in resid…
vasilismsr Nov 2, 2019
5cda006
added coef to sparse linear dml cate and also to LassoCVWrapper
vasilismsr Nov 2, 2019
2450a99
removed coef from within final wrapper of dmlcateestimator
vasilismsr Nov 2, 2019
fde1b7c
making ortho test deterministic
vasilismsr Nov 2, 2019
68e295c
Merge branch 'master' into vasilis/ortholearner_refactor
vasilismsr Nov 2, 2019
32e8cd4
testing notebook
vasilismsr Nov 2, 2019
b0e8390
fixed intercept problem in model_final with multidim output
vasilismsr Nov 2, 2019
5e8e0c9
comments on the crossfit function
vasilismsr Nov 2, 2019
dc129d3
handling the case where the test folds in a custom splitter in ortho …
vasilismsr Nov 2, 2019
7679328
handling the case where the test folds in a custom splitter in ortho …
vasilismsr Nov 2, 2019
c67f745
testing notebook updates
vasilismsr Nov 2, 2019
23ef44d
updated crossfit test to include fitted_inds
vasilismsr Nov 2, 2019
de8a159
comments in ortho learner
vasilismsr Nov 2, 2019
1e3825a
documentation related changes to include the new python files.
vasilismsr Nov 3, 2019
3ddb59a
more docstring for _OrthoLearner. Some small fixes to allow for child…
vasilismsr Nov 3, 2019
9b0726c
pylint errors
vasilismsr Nov 3, 2019
23cf7c2
docstring typo
vasilismsr Nov 3, 2019
c8b6a91
docstrings
vasilismsr Nov 3, 2019
c93b8b1
docstring typo
vasilismsr Nov 3, 2019
e6e9593
many changes to the docstrings of the _OrthoLearner. Addition of scor…
vasilismsr Nov 3, 2019
8fe4de9
fixed bug in score of _rlearner with multidim outcome, to average fir…
vasilismsr Nov 3, 2019
3af072d
linting
vasilismsr Nov 4, 2019
0777705
docstring small changes in utilities
vasilismsr Nov 4, 2019
7524cda
doc string updates to base cate estimator.
vasilismsr Nov 4, 2019
e93f7b1
better access to fitted nuisance models in _OrthoLearner and _RLearn…
vasilismsr Nov 4, 2019
f8fdbb2
improved docstrings
vasilismsr Nov 4, 2019
b5d5e14
linting
vasilismsr Nov 4, 2019
96bd4eb
docstring example for RLearner
vasilismsr Nov 4, 2019
919bc06
testing notebook
vasilismsr Nov 4, 2019
de3c81b
accessing fitted models_y and models_t in DMLCateEstimator
vasilismsr Nov 4, 2019
ea7680a
notebook
vasilismsr Nov 4, 2019
187f32b
added draft example implementation of DRLearner based on the _OrthoLe…
vasilismsr Nov 4, 2019
d5bc747
improving docstring by removing unnecessary highlight of code
vasilismsr Nov 5, 2019
b56048b
made sample weight and sample var keyword only in the _RLearner
vasilismsr Nov 5, 2019
5d2288f
improved docstring in cate estimator
vasilismsr Nov 5, 2019
fafbb29
added comment regarding going back from the one-hot encoding to the l…
vasilismsr Nov 5, 2019
5bd9c24
improved formatting of conditional subtraction
vasilismsr Nov 5, 2019
cfe03c1
typo in ortho learner docstring
vasilismsr Nov 5, 2019
0b7b932
docstring updated to improve on crossfit training description
vasilismsr Nov 5, 2019
c78a78d
added checks that crossfit fold structure is valid and raising approp…
vasilismsr Nov 5, 2019
42384d0
better docstring in ortho learner
vasilismsr Nov 5, 2019
c4e0cf4
W in example in ortholearner docstring
vasilismsr Nov 5, 2019
4ad02cc
simplified input checks code in ortho learner
vasilismsr Nov 5, 2019
1268bb3
updated docstring regarding how we call the split method with all the…
vasilismsr Nov 5, 2019
17f3532
added some more tests related to input None's in _crossfit
vasilismsr Nov 5, 2019
3bb63d2
added description in docstring of _param_var of StatsModelsLinearRegr…
vasilismsr Nov 5, 2019
d8d5247
docstrings in utilities
vasilismsr Nov 5, 2019
6c38a0b
removed private members from doc autosummary template.
vasilismsr Nov 5, 2019
a046515
hardcoding the autosummary template for ortho_learner and rlearner fo…
vasilismsr Nov 5, 2019
6ee1944
removed the OrthoLearner testing notebook
vasilismsr Nov 5, 2019
203f9e5
improved cate estimator docstrings
Nov 5, 2019
2628ac1
cate estimator small bug between branches
Nov 5, 2019
66084b3
cate estimator small bug between branches
Nov 5, 2019
dbd0f0c
linting errors
Nov 5, 2019
a27cccc
linting errors
Nov 5, 2019
3 changes: 3 additions & 0 deletions doc/_templates/autosummary/module.rst
@@ -3,3 +3,6 @@

.. automodule:: {{ fullname }}
:members:
:private-members:
:inherited-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions doc/reference.rst
@@ -8,6 +8,8 @@ Module reference
econml.cate_estimator
econml.deepiv
econml.dgp
econml._ortho_learner
econml._rlearner
econml.dml
econml.inference
econml.ortho_forest
587 changes: 587 additions & 0 deletions econml/_ortho_learner.py

Large diffs are not rendered by default.

318 changes: 318 additions & 0 deletions econml/_rlearner.py
@@ -0,0 +1,318 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""

The R Learner is an approach for estimating flexible non-parametric models
of conditional average treatment effects in the setting with no unobserved confounders.
The method is based on the idea of Neyman orthogonality and estimates a CATE
whose mean squared error is robust to the estimation errors of auxiliary submodels
that also need to be estimated from data:

1) the outcome or regression model
2) the treatment or propensity or policy or logging policy model

References
----------

Xinkun Nie, Stefan Wager (2017). Quasi-Oracle Estimation of Heterogeneous Treatment Effects.
https://arxiv.org/abs/1712.04912

Dylan Foster, Vasilis Syrgkanis (2019). Orthogonal Statistical Learning.
ACM Conference on Learning Theory. https://arxiv.org/abs/1901.09036

Chernozhukov et al. (2017). Double/debiased machine learning for treatment and structural parameters.
The Econometrics Journal. https://arxiv.org/abs/1608.00060
"""

import numpy as np
import copy
from warnings import warn
from .utilities import (shape, reshape, ndim, hstack, cross_product, transpose,
broadcast_unit_treatments, reshape_treatmentwise_effects,
StatsModelsLinearRegression, LassoCVWrapper)
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.linear_model import LinearRegression, LassoCV
from sklearn.preprocessing import (PolynomialFeatures, LabelEncoder, OneHotEncoder,
FunctionTransformer)
from sklearn.base import clone, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.utils import check_random_state
from .cate_estimator import (BaseCateEstimator, LinearCateEstimator,
TreatmentExpansionMixin, StatsModelsCateEstimatorMixin)
from .inference import StatsModelsInference
from ._ortho_learner import _OrthoLearner


class _RLearner(_OrthoLearner):
"""
Base class for CATE learners that use the R-learner (residual-on-residual) approach.

Parameters
----------
model_y: estimator of E[Y | X, W]
The estimator for fitting the response to the features and controls. Must implement
`fit` and `predict` methods. Unlike sklearn estimators both methods must
take an extra second argument (the controls), i.e. ::

model_y.fit(X, W, Y, sample_weight=sample_weight)
model_y.predict(X, W)

model_t: estimator of E[T | X, W]
The estimator for fitting the treatment to the features and controls. Must implement
`fit` and `predict` methods. Unlike sklearn estimators both methods must
take an extra second argument (the controls), i.e. ::

model_t.fit(X, W, T, sample_weight=sample_weight)
model_t.predict(X, W)

model_final: estimator for fitting the response residuals to the features and treatment residuals
Must implement `fit` and `predict` methods. Unlike sklearn estimators the fit methods must
take an extra second argument (the treatment residuals). Predict, on the other hand,
should just take the features and return the constant marginal effect. More concretely::

model_final.fit(X, T_res, Y_res,
sample_weight=sample_weight, sample_var=sample_var)
model_final.predict(X)

discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities

n_splits: int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for n_splits are:

- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds,
- a :term:`CV splitter`,
- an iterable yielding (train, test) splits as arrays of indices.

For integer/None inputs, if the treatment is discrete
:class:`~sklearn.model_selection.StratifiedKFold` is used, else,
:class:`~sklearn.model_selection.KFold` is used
(with a random shuffle in either case).

Unless an iterable is used, we call `split(X, T)` to generate the splits.

random_state: int, :class:`~numpy.random.mtrand.RandomState` instance or None
If int, random_state is the seed used by the random number generator;
If :class:`~numpy.random.mtrand.RandomState` instance, random_state is the random number generator;
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by `np.random`.

Examples
--------
The example code below implements a very simple version of the double machine learning
method on top of the :py:class:`~econml._rlearner._RLearner` class, for expository purposes.
For a more elaborate implementation of a Double Machine Learning child class, check out
:py:class:`~econml.dml.DMLCateEstimator` and its child classes::

import numpy as np
from sklearn.linear_model import LinearRegression
from econml._rlearner import _RLearner
from sklearn.base import clone
class ModelFirst:
def __init__(self, model):
self._model = clone(model, safe=False)
def fit(self, X, W, Y, sample_weight=None):
self._model.fit(np.hstack([X, W]), Y)
return self
def predict(self, X, W):
return self._model.predict(np.hstack([X, W]))
class ModelFinal:
def fit(self, X, T_res, Y_res, sample_weight=None, sample_var=None):
self.model = LinearRegression(fit_intercept=False).fit(X * T_res.reshape(-1, 1),
Y_res)
return self
def predict(self, X):
return self.model.predict(X)
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
est = _RLearner(ModelFirst(LinearRegression()),
ModelFirst(LinearRegression()),
ModelFinal(),
n_splits=2, discrete_treatment=False, random_state=None)
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])

>>> est.const_marginal_effect(np.ones((1,1)))
array([0.99963147])
>>> est.effect(np.ones((1,1)), T0=0, T1=10)
array([9.99631472])
>>> est.score(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
9.736380060274913e-05
>>> est.model_final.model
LinearRegression(copy_X=True, fit_intercept=False, n_jobs=None,
normalize=False)
>>> est.model_final.model.coef_
array([0.99963147])
>>> est.score_
9.826232040878233e-05
>>> [mdl._model for mdl in est.models_y]
[LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False),
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)]
>>> [mdl._model for mdl in est.models_t]
[LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False),
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)]

Attributes
----------
models_y: list of objects of type(model_y)
A list of instances of the model_y object. Each element corresponds to a crossfitting
fold and is the model instance that was fitted for that training fold.
models_t: list of objects of type(model_t)
A list of instances of the model_t object. Each element corresponds to a crossfitting
fold and is the model instance that was fitted for that training fold.
model_final : object of type(model_final)
An instance of the model_final object that was fitted after calling fit.
score_ : float
The MSE in the final residual on residual regression, i.e.

.. math::
\\frac{1}{n} \\sum_{i=1}^n (Y_i - \\hat{E}[Y|X_i, W_i]\
- \\hat{\\theta}(X_i)\\cdot (T_i - \\hat{E}[T|X_i, W_i]))^2

If `sample_weight` is not None at fit time, then a weighted average is returned. If the outcome Y
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
"""

def __init__(self, model_y, model_t, model_final,
discrete_treatment, n_splits, random_state):
class ModelNuisance:
"""
At fit time, the nuisance model fits model_y and model_t; at predict time, it computes
the Y and T residuals from the fitted models and returns the residuals as the two
nuisance parameters.
"""

def __init__(self, model_y, model_t):
self._model_y = clone(model_y, safe=False)
self._model_t = clone(model_t, safe=False)

def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
assert Z is None, "Cannot accept instrument!"
self._model_t.fit(X, W, T, sample_weight=sample_weight)
self._model_y.fit(X, W, Y, sample_weight=sample_weight)
return self

def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
Y_pred = self._model_y.predict(X, W)
T_pred = self._model_t.predict(X, W)
if (X is None) and (W is None): # In this case predict above returns a single row
Y_pred = np.tile(Y_pred, Y.shape[0])
T_pred = np.tile(T_pred, T.shape[0])
Y_res = Y - Y_pred.reshape(Y.shape)
T_res = T - T_pred.reshape(T.shape)
return Y_res, T_res

class ModelFinal:
"""
At fit time, the final model fits a residual-on-residual regression with a heterogeneous
coefficient that depends on X, i.e.

.. math ::
Y - E[Y | X, W] = \\theta(X) \\cdot (T - E[T | X, W]) + \\epsilon

and at predict time returns :math:`\\theta(X)`. The score method returns the MSE of this final
residual on residual regression.
"""

def __init__(self, model_final):
self._model_final = clone(model_final, safe=False)

def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
Y_res, T_res = nuisances
self._model_final.fit(X, T_res, Y_res, sample_weight=sample_weight, sample_var=sample_var)
return self

def predict(self, X=None):
return self._model_final.predict(X)

def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
Y_res, T_res = nuisances
if Y_res.ndim == 1:
Y_res = Y_res.reshape((-1, 1))
if T_res.ndim == 1:
T_res = T_res.reshape((-1, 1))
effects = self._model_final.predict(X).reshape((-1, Y_res.shape[1], T_res.shape[1]))
Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res).reshape(Y_res.shape)
if sample_weight is not None:
return np.mean(np.average((Y_res - Y_res_pred)**2, weights=sample_weight, axis=0))
else:
return np.mean((Y_res - Y_res_pred)**2)

super().__init__(ModelNuisance(model_y, model_t),
ModelFinal(model_final), discrete_treatment, n_splits, random_state)

def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates the function :math:`\\theta(\\cdot)`.

Parameters
----------
Y: (n, d_y) matrix or vector of length n
Outcomes for each sample
T: (n, d_t) matrix or vector of length n
Treatments for each sample
X: optional(n, d_x) matrix or None (Default=None)
Features for each sample
W: optional(n, d_w) matrix or None (Default=None)
Controls for each sample
sample_weight: optional(n,) vector or None (Default=None)
Weights for each sample
sample_var: optional(n,) vector or None (Default=None)
Sample variance for each sample
inference: string, `Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of `BootstrapInference`).

Returns
-------
self: _RLearner instance
"""
# Replacing fit from _OrthoLearner, to enforce Z=None and improve the docstring
return super().fit(Y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var, inference=inference)

def score(self, Y, T, X=None, W=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
for the new data set based on the fitted residual nuisance models created at fit time.
It uses the mean prediction of the models fitted by the different crossfit folds.
Then calculates the MSE of the final residual Y on residual T regression.

If model_final does not have a score method, then it raises an `AttributeError`

Parameters
----------
Y: (n, d_y) matrix or vector of length n
Outcomes for each sample
T: (n, d_t) matrix or vector of length n
Treatments for each sample
X: optional(n, d_x) matrix or None (Default=None)
Features for each sample
W: optional(n, d_w) matrix or None (Default=None)
Controls for each sample

Returns
-------
score: float
The MSE of the final CATE model on the new data.
"""
# Replacing score from _OrthoLearner, to enforce Z=None and improve the docstring
return super().score(Y, T, X=X, W=W)

@property
def model_final(self):
return super().model_final._model_final

@property
def models_y(self):
return [mdl._model_y for mdl in super().models_nuisance]

@property
def models_t(self):
return [mdl._model_t for mdl in super().models_nuisance]
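
The score computed by the nested `ModelFinal.score` above contracts the `(n, d_y, d_t)` effect array with the `(n, d_t)` treatment residuals via `np.einsum('ijk,ik->ij', ...)` to get predicted outcome residuals, then takes the (optionally weighted) MSE. A minimal standalone NumPy sketch of that computation, not part of this PR — the `rlearner_score` helper and the toy arrays are hypothetical, for illustration only:

```python
import numpy as np

def rlearner_score(effects, T_res, Y_res, sample_weight=None):
    # effects: (n, d_y, d_t) array of theta(X_i); T_res: (n, d_t); Y_res: (n, d_y).
    # Contract over the treatment axis to get the predicted outcome residuals
    # theta(X_i) . T_res_i, shape (n, d_y) -- same einsum as ModelFinal.score.
    Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res)
    if sample_weight is not None:
        # Weighted MSE per outcome dimension, then averaged across outcomes.
        return np.mean(np.average((Y_res - Y_res_pred) ** 2,
                                  weights=sample_weight, axis=0))
    return np.mean((Y_res - Y_res_pred) ** 2)

# Toy check: with a constant unit effect and T_res equal to Y_res,
# the residual-on-residual fit is exact, so the MSE is zero.
n, d_y, d_t = 5, 1, 1
effects = np.ones((n, d_y, d_t))
T_res = np.arange(n, dtype=float).reshape(n, d_t)
Y_res = T_res.copy()
print(rlearner_score(effects, T_res, Y_res))  # → 0.0
```

Uniform weights reduce to the unweighted case, which is why `ModelFinal.score` only branches on whether `sample_weight` is None rather than normalizing separately.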