112 changes: 64 additions & 48 deletions doubleml/irm/apo.py
@@ -183,7 +183,7 @@ def weights(self):
         return self._weights
 
     def _initialize_ml_nuisance_params(self):
-        valid_learner = ["ml_g0", "ml_g1", "ml_m"]
+        valid_learner = ["ml_g_d_lvl0", "ml_g_d_lvl1", "ml_m"]
         self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in valid_learner}
 
     def _initialize_weights(self, weights):
@@ -207,64 +207,68 @@ def _get_weights(self):
 
     def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
+        x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+        dx = np.column_stack((d, x))
         # use the treated indicator to get the correct sample splits
-        x, treated = check_X_y(x, self.treated, force_all_finite=False)
+        treated = self.treated
 
         # get train indices for d == treatment_level
         smpls_d0, smpls_d1 = _get_cond_smpls(smpls, treated)
-        g0_external = external_predictions["ml_g0"] is not None
-        g1_external = external_predictions["ml_g1"] is not None
+        g_d_lvl0_external = external_predictions["ml_g_d_lvl0"] is not None
+        g_d_lvl1_external = external_predictions["ml_g_d_lvl1"] is not None
        m_external = external_predictions["ml_m"] is not None
 
-        # nuisance g (g0 only relevant for sensitivity analysis)
-        if g0_external:
+        # nuisance g_d_lvl1 (enters the score as the (average) counterfactual outcome)
+        if g_d_lvl1_external:
             # use external predictions
-            g_hat0 = {
-                "preds": external_predictions["ml_g0"],
-                "targets": _cond_targets(y, cond_sample=(treated == 0)),
+            g_hat_d_lvl1 = {
+                "preds": external_predictions["ml_g_d_lvl1"],
+                "targets": _cond_targets(y, cond_sample=(treated == 1)),
                 "models": None,
             }
         else:
-            g_hat0 = _dml_cv_predict(
+            g_hat_d_lvl1 = _dml_cv_predict(
                 self._learner["ml_g"],
                 x,
                 y,
-                smpls=smpls_d0,
+                smpls=smpls_d1,
                 n_jobs=n_jobs_cv,
-                est_params=self._get_params("ml_g0"),
+                est_params=self._get_params("ml_g_d_lvl1"),
                 method=self._predict_method["ml_g"],
                 return_models=return_models,
             )
-            _check_finite_predictions(g_hat0["preds"], self._learner["ml_g"], "ml_g", smpls)
-            g_hat0["targets"] = _cond_targets(g_hat0["targets"], cond_sample=(treated == 0))
+            _check_finite_predictions(g_hat_d_lvl1["preds"], self._learner["ml_g"], "ml_g", smpls)
+            # adjust target values to consider only compatible subsamples
+            g_hat_d_lvl1["targets"] = _cond_targets(g_hat_d_lvl1["targets"], cond_sample=(treated == 1))
 
         if self._dml_data.binary_outcome:
-            _check_binary_predictions(g_hat0["preds"], self._learner["ml_g"], "ml_g", self._dml_data.y_col)
+            _check_binary_predictions(g_hat_d_lvl1["preds"], self._learner["ml_g"], "ml_g", self._dml_data.y_col)
 
-        if g1_external:
+        # nuisance g for the other treatment levels (only relevant for the sensitivity analysis)
+        if g_d_lvl0_external:
             # use external predictions
-            g_hat1 = {
-                "preds": external_predictions["ml_g1"],
-                "targets": _cond_targets(y, cond_sample=(treated == 1)),
+            g_hat_d_lvl0 = {
+                "preds": external_predictions["ml_g_d_lvl0"],
+                "targets": _cond_targets(y, cond_sample=(treated == 0)),
                 "models": None,
             }
         else:
-            g_hat1 = _dml_cv_predict(
+            g_hat_d_lvl0 = _dml_cv_predict(
                 self._learner["ml_g"],
-                x,
+                dx,  # fit on (d, x) to pool observations across treatment levels (reduces variance in the sensitivity analysis)
                 y,
-                smpls=smpls_d1,
+                smpls=smpls_d0,
                 n_jobs=n_jobs_cv,
-                est_params=self._get_params("ml_g1"),
+                est_params=self._get_params("ml_g_d_lvl0"),
                 method=self._predict_method["ml_g"],
                 return_models=return_models,
             )
-            _check_finite_predictions(g_hat1["preds"], self._learner["ml_g"], "ml_g", smpls)
-            g_hat1["targets"] = _cond_targets(g_hat1["targets"], cond_sample=(treated == 1))
+            _check_finite_predictions(g_hat_d_lvl0["preds"], self._learner["ml_g"], "ml_g", smpls)
+            # adjust target values to consider only compatible subsamples
+            g_hat_d_lvl0["targets"] = _cond_targets(g_hat_d_lvl0["targets"], cond_sample=(treated == 0))
 
         if self._dml_data.binary_outcome:
-            _check_binary_predictions(g_hat1["preds"], self._learner["ml_g"], "ml_g", self._dml_data.y_col)
+            _check_binary_predictions(g_hat_d_lvl0["preds"], self._learner["ml_g"], "ml_g", self._dml_data.y_col)
 
         # nuisance m
         if m_external:
@@ -287,25 +291,33 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
             # also trim external predictions
             m_hat["preds"] = _trimm(m_hat["preds"], self.trimming_rule, self.trimming_threshold)
 
-        psi_a, psi_b = self._score_elements(y, treated, g_hat0["preds"], g_hat1["preds"], m_hat["preds"], smpls)
+        psi_a, psi_b = self._score_elements(y, treated, g_hat_d_lvl0["preds"], g_hat_d_lvl1["preds"], m_hat["preds"], smpls)
         psi_elements = {"psi_a": psi_a, "psi_b": psi_b}
 
         preds = {
-            "predictions": {"ml_g0": g_hat0["preds"], "ml_g1": g_hat1["preds"], "ml_m": m_hat["preds"]},
-            "targets": {"ml_g0": g_hat0["targets"], "ml_g1": g_hat1["targets"], "ml_m": m_hat["targets"]},
-            "models": {"ml_g0": g_hat0["models"], "ml_g1": g_hat1["models"], "ml_m": m_hat["models"]},
+            "predictions": {
+                "ml_g_d_lvl0": g_hat_d_lvl0["preds"],
+                "ml_g_d_lvl1": g_hat_d_lvl1["preds"],
+                "ml_m": m_hat["preds"],
+            },
+            "targets": {
+                "ml_g_d_lvl0": g_hat_d_lvl0["targets"],
+                "ml_g_d_lvl1": g_hat_d_lvl1["targets"],
+                "ml_m": m_hat["targets"],
+            },
+            "models": {"ml_g_d_lvl0": g_hat_d_lvl0["models"], "ml_g_d_lvl1": g_hat_d_lvl1["models"], "ml_m": m_hat["models"]},
         }
         return psi_elements, preds
 
-    def _score_elements(self, y, treated, g_hat0, g_hat1, m_hat, smpls):
+    def _score_elements(self, y, treated, g_hat_d_lvl0, g_hat_d_lvl1, m_hat, smpls):
         if self.normalize_ipw:
             m_hat_adj = _normalize_ipw(m_hat, treated)
         else:
             m_hat_adj = m_hat
 
-        u_hat = y - g_hat1
+        u_hat = y - g_hat_d_lvl1
         weights, weights_bar = self._get_weights()
-        psi_b = weights * g_hat1 + weights_bar * np.divide(np.multiply(treated, u_hat), m_hat_adj)
+        psi_b = weights * g_hat_d_lvl1 + weights_bar * np.divide(np.multiply(treated, u_hat), m_hat_adj)
         psi_a = -1.0 * np.divide(weights, np.mean(weights))  # TODO: check if this is correct
 
         return psi_a, psi_b
@@ -316,12 +328,12 @@ def _sensitivity_element_est(self, preds):
         treated = self.treated
 
         m_hat = preds["predictions"]["ml_m"]
-        g_hat0 = preds["predictions"]["ml_g0"]
-        g_hat1 = preds["predictions"]["ml_g1"]
+        g_hat_d_lvl0 = preds["predictions"]["ml_g_d_lvl0"]
+        g_hat_d_lvl1 = preds["predictions"]["ml_g_d_lvl1"]
 
         weights, weights_bar = self._get_weights()
 
-        sigma2_score_element = np.square(y - np.multiply(treated, g_hat1) - np.multiply(1.0 - treated, g_hat0))
+        sigma2_score_element = np.square(y - np.multiply(treated, g_hat_d_lvl1) - np.multiply(1.0 - treated, g_hat_d_lvl0))
         sigma2 = np.mean(sigma2_score_element)
         psi_sigma2 = sigma2_score_element - sigma2
 
@@ -346,20 +358,24 @@ def _nuisance_tuning(
         self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search
     ):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
-        x, treated = check_X_y(x, self.treated, force_all_finite=False)
+        x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+        dx = np.column_stack((d, x))
+        # use the treated indicator to get the correct sample splits
+        treated = self.treated
 
         # get train indices for treated == 0 and treated == 1
         smpls_d0, smpls_d1 = _get_cond_smpls(smpls, treated)
 
         if scoring_methods is None:
             scoring_methods = {"ml_g": None, "ml_m": None}
 
         train_inds = [train_index for (train_index, _) in smpls]
-        train_inds_d0 = [train_index for (train_index, _) in smpls_d0]
-        train_inds_d1 = [train_index for (train_index, _) in smpls_d1]
-        g0_tune_res = _dml_tune(
+        train_inds_d_lvl0 = [train_index for (train_index, _) in smpls_d0]
+        train_inds_d_lvl1 = [train_index for (train_index, _) in smpls_d1]
+        g_d_lvl0_tune_res = _dml_tune(
             y,
-            x,
-            train_inds_d0,
+            dx,  # fit on (d, x) to pool observations across treatment levels (reduces variance in the sensitivity analysis)
+            train_inds_d_lvl0,
             self._learner["ml_g"],
             param_grids["ml_g"],
             scoring_methods["ml_g"],
@@ -368,10 +384,10 @@ def _nuisance_tuning(
             search_mode,
             n_iter_randomized_search,
         )
-        g1_tune_res = _dml_tune(
+        g_d_lvl1_tune_res = _dml_tune(
             y,
             x,
-            train_inds_d1,
+            train_inds_d_lvl1,
             self._learner["ml_g"],
             param_grids["ml_g"],
             scoring_methods["ml_g"],
@@ -394,12 +410,12 @@ def _nuisance_tuning(
             n_iter_randomized_search,
         )
 
-        g0_best_params = [xx.best_params_ for xx in g0_tune_res]
-        g1_best_params = [xx.best_params_ for xx in g1_tune_res]
+        g_d_lvl0_best_params = [xx.best_params_ for xx in g_d_lvl0_tune_res]
+        g_d_lvl1_best_params = [xx.best_params_ for xx in g_d_lvl1_tune_res]
         m_best_params = [xx.best_params_ for xx in m_tune_res]
 
-        params = {"ml_g0": g0_best_params, "ml_g1": g1_best_params, "ml_m": m_best_params}
-        tune_res = {"g0_tune": g0_tune_res, "g1_tune": g1_tune_res, "m_tune": m_tune_res}
+        params = {"ml_g_d_lvl0": g_d_lvl0_best_params, "ml_g_d_lvl1": g_d_lvl1_best_params, "ml_m": m_best_params}
+        tune_res = {"g_d_lvl0_tune": g_d_lvl0_tune_res, "g_d_lvl1_tune": g_d_lvl1_tune_res, "m_tune": m_tune_res}
 
         res = {"params": params, "tune_res": tune_res}
 
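Note on the score (not part of the diff): `_score_elements` above implements the weighted linear score for the average potential outcome theta = E[Y(d)]. Writing \hat{g}_d for `g_hat_d_lvl1` (the outcome regression at the target treatment level), \hat{m} for the possibly normalized propensity `m_hat_adj`, 1{D = d} for `treated`, and (omega, omega_bar) for the weights from `_get_weights`, the code computes

$$ \psi_a = -\frac{\omega}{\frac{1}{n}\sum_{i=1}^n \omega_i}, \qquad \psi_b = \omega\,\hat{g}_d(X) + \bar{\omega}\,\frac{\mathbf{1}\{D = d\}\,\bigl(Y - \hat{g}_d(X)\bigr)}{\hat{m}(X)}, $$

so the estimate solves E_n[psi_a] * theta + E_n[psi_b] = 0. With the default weights omega = omega_bar = 1 this reduces to the standard AIPW (doubly robust) form for E[Y(d)]; the TODO next to psi_a flags the weight normalization as still under review, so treat this reading as provisional.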
74 changes: 57 additions & 17 deletions doubleml/irm/apos.py
@@ -11,6 +11,7 @@
 from ..double_ml_framework import concat
 from ..utils._checks import _check_sample_splitting, _check_score, _check_trimming, _check_weights
 from ..utils._descriptive import generate_summary
+from ..utils._sensitivity import _compute_sensitivity_bias
 from ..utils.gain_statistics import gain_statistics
 from ..utils.resampling import DoubleMLResampling
 from .apo import DoubleMLAPO
@@ -733,22 +734,29 @@ def causal_contrast(self, reference_levels):
                 "a single treatment level."
             )
 
-        skip_index = []
+        skip_index = set()
         all_treatment_names = []
         all_acc_frameworks = []
 
         for ref_lvl in reference_levels:
             i_ref_lvl = self.treatment_levels.index(ref_lvl)
-            ref_framework = self.modellist[i_ref_lvl].framework
+            ref_model = self.modellist[i_ref_lvl]
 
-            skip_index += [i_ref_lvl]
-            all_acc_frameworks += [
-                model.framework - ref_framework for i, model in enumerate(self.modellist) if i not in skip_index
-            ]
-            all_treatment_names += [
-                f"{self.treatment_levels[i]} vs {self.treatment_levels[i_ref_lvl]}"
-                for i in range(self.n_treatment_levels)
-                if i not in skip_index
-            ]
+            skip_index.add(i_ref_lvl)
+            for i, model in enumerate(self.modellist):
+                # skip comparisons that were already computed
+                if i in skip_index:
+                    continue
+
+                current_framework = model.framework - ref_model.framework
+                current_treatment_name = f"{self.treatment_levels[i]} vs {self.treatment_levels[i_ref_lvl]}"
+
+                # update sensitivity elements with sharper bounds
+                current_sensitivity_dict = self._compute_causal_contrast_sensitivity_dict(model=model, ref_model=ref_model)
+                current_framework._check_and_set_sensitivity_elements(current_sensitivity_dict)
+
+                all_acc_frameworks += [current_framework]
+                all_treatment_names += [current_treatment_name]
 
         acc = concat(all_acc_frameworks)
         acc.treatment_names = all_treatment_names
Expand All @@ -768,6 +776,38 @@ def _fit_model(self, i_level, n_jobs_cv=None, store_predictions=True, store_mode
)
return model

def _compute_causal_contrast_sensitivity_dict(self, model, ref_model):
# reshape sensitivity elements to (1 or n_obs, n_coefs, n_rep)
model_sigma2 = np.transpose(model.sensitivity_elements["sigma2"], (0, 2, 1))
model_nu2 = np.transpose(model.sensitivity_elements["nu2"], (0, 2, 1))
model_psi_sigma2 = np.transpose(model.sensitivity_elements["psi_sigma2"], (0, 2, 1))
model_psi_nu2 = np.transpose(model.sensitivity_elements["psi_nu2"], (0, 2, 1))

ref_model_sigma2 = np.transpose(ref_model.sensitivity_elements["sigma2"], (0, 2, 1))
ref_model_nu2 = np.transpose(ref_model.sensitivity_elements["nu2"], (0, 2, 1))
ref_model_psi_sigma2 = np.transpose(ref_model.sensitivity_elements["psi_sigma2"], (0, 2, 1))
ref_model_psi_nu2 = np.transpose(ref_model.sensitivity_elements["psi_nu2"], (0, 2, 1))

combined_sensitivity_dict = {
"sigma2": (model_sigma2 + ref_model_sigma2) / 2,
"nu2": model_nu2 + ref_model_nu2,
"psi_sigma2": (model_psi_sigma2 + ref_model_psi_sigma2) / 2,
"psi_nu2": model_psi_nu2 + ref_model_psi_nu2,
}

max_bias, psi_max_bias = _compute_sensitivity_bias(**combined_sensitivity_dict)

new_sensitvitiy_dict = {
"sensitivity_elements": {
"max_bias": max_bias,
"psi_max_bias": psi_max_bias,
"sigma2": combined_sensitivity_dict["sigma2"],
"nu2": combined_sensitivity_dict["nu2"],
}
}

return new_sensitvitiy_dict

def _check_treatment_levels(self, treatment_levels):
is_iterable = isinstance(treatment_levels, Iterable)
if not is_iterable:
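Note (not part of the diff): for a contrast theta_lvl - theta_ref, the helper above pools the per-model sensitivity elements as

$$ \sigma^2_{\mathrm{comb}} = \tfrac{1}{2}\bigl(\sigma^2_{\mathrm{lvl}} + \sigma^2_{\mathrm{ref}}\bigr), \qquad \nu^2_{\mathrm{comb}} = \nu^2_{\mathrm{lvl}} + \nu^2_{\mathrm{ref}}, $$

and analogously for the influence-function elements psi_sigma2 and psi_nu2. Assuming `_compute_sensitivity_bias` follows the usual DoubleML convention max_bias = sqrt(sigma2 * nu2), each contrast then receives a single bound computed for the difference of the two scores, rather than the generic combination that plain framework subtraction would produce; these are the "sharper bounds" referred to in the comment inside `causal_contrast`.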
@@ -803,7 +843,7 @@ def _check_external_predictions(self, external_predictions):
                 + f"Passed keys: {set(external_predictions.keys())}."
             )
 
-        expected_learner_keys = ["ml_g0", "ml_g1", "ml_m"]
+        expected_learner_keys = ["ml_g_d_lvl0", "ml_g_d_lvl1", "ml_m"]
         for key, value in external_predictions.items():
             if not isinstance(value, dict):
                 raise TypeError(
@@ -821,12 +861,12 @@ def _rename_external_predictions(self, external_predictions):
         d_col = self._dml_data.d_cols[0]
         ext_pred_dict = {treatment_level: {d_col: {}} for treatment_level in self.treatment_levels}
         for treatment_level in self.treatment_levels:
-            if "ml_g1" in external_predictions[treatment_level]:
-                ext_pred_dict[treatment_level][d_col]["ml_g1"] = external_predictions[treatment_level]["ml_g1"]
+            if "ml_g_d_lvl1" in external_predictions[treatment_level]:
+                ext_pred_dict[treatment_level][d_col]["ml_g_d_lvl1"] = external_predictions[treatment_level]["ml_g_d_lvl1"]
             if "ml_m" in external_predictions[treatment_level]:
                 ext_pred_dict[treatment_level][d_col]["ml_m"] = external_predictions[treatment_level]["ml_m"]
-            if "ml_g0" in external_predictions[treatment_level]:
-                ext_pred_dict[treatment_level][d_col]["ml_g0"] = external_predictions[treatment_level]["ml_g0"]
+            if "ml_g_d_lvl0" in external_predictions[treatment_level]:
+                ext_pred_dict[treatment_level][d_col]["ml_g_d_lvl0"] = external_predictions[treatment_level]["ml_g_d_lvl0"]
 
         return ext_pred_dict
 
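For orientation, here is a minimal end-to-end sketch of the updated API (not part of the PR; the synthetic data, learners, and parameter values are illustrative assumptions):

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

import doubleml as dml

# toy data with a three-valued treatment (purely illustrative)
rng = np.random.default_rng(42)
n = 2000
x = rng.normal(size=(n, 5))
d = rng.integers(0, 3, size=n)  # treatment levels 0, 1, 2
y = x[:, 0] + 0.5 * d + rng.normal(size=n)
df = pd.DataFrame({"y": y, "d": d, **{f"x{i}": x[:, i] for i in range(5)}})
dml_data = dml.DoubleMLData(df, y_col="y", d_cols="d")

dml_apos = dml.DoubleMLAPOS(
    dml_data,
    ml_g=RandomForestRegressor(n_estimators=50, max_depth=5),
    ml_m=RandomForestClassifier(n_estimators=50, max_depth=5),
    treatment_levels=[0, 1, 2],
)
dml_apos.fit()

# contrasts "1 vs 0" and "2 vs 0"; their sensitivity elements now come from
# _compute_causal_contrast_sensitivity_dict instead of plain framework subtraction
contrast = dml_apos.causal_contrast(reference_levels=[0])
contrast.sensitivity_analysis(cf_y=0.03, cf_d=0.03)
print(contrast.sensitivity_summary)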
12 changes: 7 additions & 5 deletions doubleml/irm/tests/_utils_apo_manual.py
@@ -87,12 +87,13 @@ def fit_nuisance_apo(
 ):
     ml_g0 = clone(learner_g)
     ml_g1 = clone(learner_g)
+    dx = np.column_stack((d, x))
 
     train_cond0 = np.where(treated == 0)[0]
     if is_classifier(learner_g):
-        g_hat0_list = fit_predict_proba(y, x, ml_g0, g0_params, smpls, train_cond=train_cond0)
+        g_hat0_list = fit_predict_proba(y, dx, ml_g0, g0_params, smpls, train_cond=train_cond0)
     else:
-        g_hat0_list = fit_predict(y, x, ml_g0, g0_params, smpls, train_cond=train_cond0)
+        g_hat0_list = fit_predict(y, dx, ml_g0, g0_params, smpls, train_cond=train_cond0)
 
     train_cond1 = np.where(treated == 1)[0]
     if is_classifier(learner_g):
@@ -223,8 +224,8 @@ def fit_sensitivity_elements_apo(y, d, treatment_level, all_coef, predictions, s
 
     for i_rep in range(n_rep):
         m_hat = predictions["ml_m"][:, i_rep, 0]
-        g_hat0 = predictions["ml_g0"][:, i_rep, 0]
-        g_hat1 = predictions["ml_g1"][:, i_rep, 0]
+        g_hat0 = predictions["ml_g_d_lvl0"][:, i_rep, 0]
+        g_hat1 = predictions["ml_g_d_lvl1"][:, i_rep, 0]
 
         weights = np.ones_like(d)
         weights_bar = np.ones_like(d)
@@ -246,8 +247,9 @@
 
 
 def tune_nuisance_apo(y, x, d, treatment_level, ml_g, ml_m, smpls, score, n_folds_tune, param_grid_g, param_grid_m):
+    dx = np.column_stack((d, x))
     train_cond0 = np.where(d != treatment_level)[0]
-    g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=train_cond0)
+    g0_tune_res = tune_grid_search(y, dx, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=train_cond0)
 
     train_cond1 = np.where(d == treatment_level)[0]
     g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=train_cond1)
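The recurring `dx = np.column_stack((d, x))` change in these test helpers mirrors the library change: the outcome model for the non-target treatment levels is fit once on treatment-augmented features, pooling all observations with d != treatment_level, instead of one fit per level. A toy sketch of that pattern (illustrative only, not library code):

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n = 1000
x = rng.normal(size=(n, 3))
d = rng.integers(0, 3, size=n)
y = x[:, 0] + d + rng.normal(size=n)

treatment_level = 1
other = d != treatment_level
dx = np.column_stack((d, x))  # same construction as dx in the diff

# one pooled fit over all other treatment levels ...
ml_g = LinearRegression().fit(dx[other], y[other])
# ... evaluated at each observation's own treatment level
g_hat_other = ml_g.predict(dx)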
6 changes: 3 additions & 3 deletions doubleml/irm/tests/test_apo.py
@@ -44,7 +44,7 @@ def treatment_level(request):
 
 
 @pytest.fixture(scope="module")
-def dml_apo_fixture(generate_data_irm, learner, normalize_ipw, trimming_threshold, treatment_level):
+def dml_apo_fixture(learner, normalize_ipw, trimming_threshold, treatment_level):
     boot_methods = ["normal"]
     n_folds = 2
     n_rep_boot = 499
@@ -116,8 +116,8 @@ def dml_apo_fixture(generate_data_irm, learner, normalize_ipw, trimming_threshold, treatment_level):
 
     prediction_dict = {
         "d": {
-            "ml_g0": dml_obj.predictions["ml_g0"].reshape(-1, 1),
-            "ml_g1": dml_obj.predictions["ml_g1"].reshape(-1, 1),
+            "ml_g_d_lvl0": dml_obj.predictions["ml_g_d_lvl0"].reshape(-1, 1),
+            "ml_g_d_lvl1": dml_obj.predictions["ml_g_d_lvl1"].reshape(-1, 1),
             "ml_m": dml_obj.predictions["ml_m"].reshape(-1, 1),
         }
     }
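Since the prediction keys changed, externally supplied predictions have to use the new names as well. A self-contained sketch mirroring the test's `prediction_dict` (assumed public API, hypothetical data and learners):

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression

import doubleml as dml

rng = np.random.default_rng(0)
n = 500
x = rng.normal(size=(n, 4))
d = rng.integers(0, 2, size=n)
y = x[:, 0] + d + rng.normal(size=n)
df = pd.DataFrame({"y": y, "d": d, **{f"x{i}": x[:, i] for i in range(4)}})
dml_data = dml.DoubleMLData(df, y_col="y", d_cols="d")

# first fit: store the nuisance predictions
dml_obj = dml.DoubleMLAPO(dml_data, LinearRegression(), RandomForestClassifier(), treatment_level=0)
dml_obj.fit(store_predictions=True)

prediction_dict = {
    "d": {
        "ml_g_d_lvl0": dml_obj.predictions["ml_g_d_lvl0"].reshape(-1, 1),
        "ml_g_d_lvl1": dml_obj.predictions["ml_g_d_lvl1"].reshape(-1, 1),
        "ml_m": dml_obj.predictions["ml_m"].reshape(-1, 1),
    }
}

# second fit: reuse them under the renamed keys
dml_obj_ext = dml.DoubleMLAPO(dml_data, LinearRegression(), RandomForestClassifier(), treatment_level=0)
dml_obj_ext.fit(external_predictions=prediction_dict)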