Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
4969a4d
add first elements for sensitivity in framework class
SvenKlaassen Jun 7, 2024
d4646bb
Merge branch 'main' into s-add-sensitivity-framework
SvenKlaassen Jun 11, 2024
3add93e
add sensitivity elements to framework
SvenKlaassen Jun 16, 2024
d95bdef
extend exception tests
SvenKlaassen Jun 16, 2024
b308d6e
add tests for shapes
SvenKlaassen Jun 17, 2024
ebc8b39
add riesz representer to sensitivity_elements
SvenKlaassen Jun 17, 2024
33b68b3
add sensitivity calculation for framework operations
SvenKlaassen Jun 17, 2024
e6db5b4
add riesz_rep to framework
SvenKlaassen Jun 17, 2024
40d8a6f
add return type tests for sensitivity_elements
SvenKlaassen Jun 17, 2024
26d65a9
fix nu2 dimension
SvenKlaassen Jun 17, 2024
2254346
add exception tests for riesz rep in framework
SvenKlaassen Jun 17, 2024
6ef2c96
extend sensitivity framework tests
SvenKlaassen Jun 17, 2024
2443ccd
add sensitivity_framework to concat
SvenKlaassen Jun 17, 2024
747e9cd
add first _calc_sensitivity_analysis()
SvenKlaassen Jun 17, 2024
92b8699
add basic sensitivity calculations
SvenKlaassen Jun 17, 2024
6c19e7d
update sensitivity_initialisation
SvenKlaassen Jun 17, 2024
a659b67
add variance check
SvenKlaassen Jun 17, 2024
ceb9117
further update framework
SvenKlaassen Jun 17, 2024
86e2ef0
move sensitivity test to extra file
SvenKlaassen Jun 18, 2024
1547913
raise error for concat clustering
SvenKlaassen Jun 18, 2024
9e63440
add input check for cluster data
SvenKlaassen Jun 18, 2024
c36dcf0
further input check for cluster data
SvenKlaassen Jun 18, 2024
47d9b76
update set sensitivity elements
SvenKlaassen Jun 18, 2024
ea7773a
add cluster dict to doubleml class
SvenKlaassen Jun 18, 2024
de4d964
fix exception test
SvenKlaassen Jun 18, 2024
b738cda
formatting
SvenKlaassen Jun 18, 2024
2a33717
add cluster compatibility for operations
SvenKlaassen Jun 18, 2024
b887272
Update double_ml_framework.py
SvenKlaassen Jun 18, 2024
a9529f0
add simple dataset
SvenKlaassen Jun 18, 2024
2fd1434
add rv and sensitivity analysis
SvenKlaassen Jun 18, 2024
78deba6
add hypothesis exception test
SvenKlaassen Jun 18, 2024
d892bb5
move sensitivity_analysis() to framework
SvenKlaassen Jun 18, 2024
e492c6b
move _calc_sensitivity_analysis and _calc_robustness_value to framework
SvenKlaassen Jun 18, 2024
beac0c9
fix format
SvenKlaassen Jun 18, 2024
32fa7b8
add plot to framework
SvenKlaassen Jun 18, 2024
0db982b
fix format
SvenKlaassen Jun 18, 2024
2feb34a
move sensitivity_plot to framework
SvenKlaassen Jun 18, 2024
09ea568
update sensitivity_plot docstring
SvenKlaassen Jun 18, 2024
a312ea4
Merge branch 'main' into s-add-sensitivity-framework
SvenKlaassen Jul 27, 2024
d23243e
Merge branch 'main' into s-add-sensitivity-framework
SvenKlaassen Aug 9, 2024
399decf
fix benchmark in sensitivity plot
SvenKlaassen Aug 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doubleml/did/did.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def _sensitivity_element_est(self, preds):
element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
4 changes: 3 additions & 1 deletion doubleml/did/did_cs.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,9 @@ def _sensitivity_element_est(self, preds):
element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
330 changes: 86 additions & 244 deletions doubleml/double_ml.py

Large diffs are not rendered by default.

574 changes: 533 additions & 41 deletions doubleml/double_ml_framework.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion doubleml/irm/irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,9 @@ def _sensitivity_element_est(self, preds):
element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
10 changes: 7 additions & 3 deletions doubleml/plm/plr.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,17 @@ def _sensitivity_element_est(self, preds):
sigma2 = np.mean(sigma2_score_element)
psi_sigma2 = sigma2_score_element - sigma2

nu2 = np.divide(1.0, np.mean(np.square(d - m_hat)))
psi_nu2 = nu2 - np.multiply(np.square(d-m_hat), np.square(nu2))
treatment_residual = d - m_hat
nu2 = np.divide(1.0, np.mean(np.square(treatment_residual)))
psi_nu2 = nu2 - np.multiply(np.square(treatment_residual), np.square(nu2))
rr = np.multiply(treatment_residual, nu2)

element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
8 changes: 6 additions & 2 deletions doubleml/tests/_utils_doubleml_sensitivity_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def doubleml_sensitivity_manual(sensitivity_elements, all_coefs, psi, psi_deriv,
theta_upper, sigma_upper = _aggregate_coefs_and_ses(all_theta_upper, all_sigma_upper, var_scaling_factor)

quant = norm.ppf(level)
ci_lower = theta_lower - np.multiply(quant, sigma_lower)
ci_upper = theta_upper + np.multiply(quant, sigma_upper)

all_ci_lower = all_theta_lower - np.multiply(quant, all_sigma_lower)
all_ci_upper = all_theta_upper + np.multiply(quant, all_sigma_upper)

ci_lower = np.median(all_ci_lower, axis=1)
ci_upper = np.median(all_ci_upper, axis=1)

theta_dict = {'lower': theta_lower,
'upper': theta_upper}
Expand Down
25 changes: 25 additions & 0 deletions doubleml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from doubleml.datasets import make_plr_turrell2018, make_irm_data, \
make_pliv_CHS2015

from doubleml import DoubleMLData


def _g(x):
    """Return the element-wise square of the sine of ``x``."""
    sine_of_x = np.sin(x)
    return np.power(sine_of_x, 2)
Expand All @@ -22,6 +24,29 @@ def _m2(x):
return np.power(x, 2)


@pytest.fixture(scope='session',
                params=[(500, 5)])
def generate_data_simple(request):
    """Build two ``DoubleMLData`` objects from one simulated data frame.

    The fixture parameter is a ``(n, p)`` pair: number of observations and
    number of covariates. Two binary treatments ``D1`` and ``D2`` are drawn;
    only ``D1`` enters the outcome ``Y`` (with coefficient 1.0), together
    with the sum of the covariates and standard normal noise.

    Returns
    -------
    tuple
        ``(data_d1, data_d2)`` -- the same frame wrapped once with ``D1`` and
        once with ``D2`` passed as the treatment column.
    """
    n_obs, n_covariates = request.param
    np.random.seed(1111)
    effect = 1.0

    # The draw order (D1, D2, X, noise) must not change: all draws come from
    # the globally seeded generator, so reordering would alter the sample.
    treat_1 = 1.0 * (np.random.uniform(size=n_obs) > 0.5)
    treat_2 = 1.0 * (np.random.uniform(size=n_obs) > 0.5)
    covariates = np.random.normal(size=(n_obs, n_covariates))
    linear_part = np.dot(covariates, np.ones(n_covariates))
    outcome = effect * treat_1 + linear_part + np.random.normal(size=n_obs)

    column_names = [f'X{j}' for j in range(1, n_covariates + 1)]
    column_names = column_names + ['Y', 'D1', 'D2']
    stacked = np.column_stack((covariates, outcome, treat_1, treat_2))
    df = pd.DataFrame(stacked, columns=column_names)

    data_d1 = DoubleMLData(df, 'Y', 'D1')
    data_d2 = DoubleMLData(df, 'Y', 'D2')

    return data_d1, data_d2


@pytest.fixture(scope='session',
params=[(500, 10),
(1000, 20),
Expand Down
52 changes: 4 additions & 48 deletions doubleml/tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,10 +1008,11 @@ def test_doubleml_sensitivity_not_yet_implemented():

dml_pliv = DoubleMLPLIV(dml_data_pliv, ml_g, ml_m, ml_r)
dml_pliv.fit()
msg = "Sensitivity analysis not yet implemented for DoubleMLPLIV."
msg = 'Sensitivity analysis is not implemented for this model.'
with pytest.raises(NotImplementedError, match=msg):
_ = dml_pliv.sensitivity_analysis()

msg = 'Sensitivity analysis not yet implemented for DoubleMLPLIV.'
with pytest.raises(NotImplementedError, match=msg):
_ = dml_pliv.sensitivity_benchmark(benchmarking_set=["X1"])

Expand All @@ -1025,77 +1026,45 @@ def test_doubleml_sensitivity_inputs():
msg = "cf_y must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=1, cf_d=0.03, rho=1.0, level=0.95)

msg = r'cf_y must be in \[0,1\). 1.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=1.0, cf_d=0.03, rho=1.0, level=0.95)

# test cf_d
msg = "cf_d must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=1, rho=1.0, level=0.95)

msg = r'cf_d must be in \[0,1\). 1.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=1.0, rho=1.0, level=0.95)

# test rho
msg = "rho must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1, level=0.95)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1, null_hypothesis=0.0, level=0.95, idx_treatment=0)

msg = "rho must be of float type. 1 of type <class 'str'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho="1")
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho="1", level=0.95)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(rho="1", null_hypothesis=0.0, level=0.95, idx_treatment=0)

msg = r'The absolute value of rho must be in \[0,1\]. 1.1 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.1)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.1, level=0.95)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.1, null_hypothesis=0.0, level=0.95, idx_treatment=0)

# test level
msg = "The confidence level must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.0, level=1, null_hypothesis=0.0, idx_treatment=0)

msg = r'The confidence level must be in \(0,1\). 1.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.0, level=1.0, null_hypothesis=0.0, idx_treatment=0)

msg = r'The confidence level must be in \(0,1\). 0.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=0.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=0.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.0, level=0.0, null_hypothesis=0.0, idx_treatment=0)

# test null_hypothesis
msg = "null_hypothesis has to be of type float or np.ndarry. 1 of type <class 'int'> was passed."
Expand All @@ -1104,30 +1073,18 @@ def test_doubleml_sensitivity_inputs():
msg = r"null_hypothesis is numpy.ndarray but does not have the required shape \(1,\). Array of shape \(2,\) was passed."
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(null_hypothesis=np.array([1, 2]))
msg = "null_hypothesis must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(null_hypothesis=1, level=0.95, rho=1.0, idx_treatment=0)
msg = r"null_hypothesis must be of float type. \[1\] of type <class 'numpy.ndarray'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(null_hypothesis=np.array([1]), level=0.95, rho=1.0, idx_treatment=0)

# test idx_treatment
dml_irm.sensitivity_analysis()
msg = "idx_treatment must be an integer. 0.0 of type <class 'float'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(idx_treatment=0.0, null_hypothesis=0.0, level=0.95, rho=1.0)
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_plot(idx_treatment=0.0)

msg = "idx_treatment must be larger or equal to 0. -1 was passed."
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(idx_treatment=-1, null_hypothesis=0.0, level=0.95, rho=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_plot(idx_treatment=-1)

msg = "idx_treatment must be smaller or equal to 0. 1 was passed."
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(idx_treatment=1, null_hypothesis=0.0, level=0.95, rho=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_plot(idx_treatment=1)

Expand All @@ -1142,7 +1099,7 @@ def test_doubleml_sensitivity_inputs():
_ = dml_irm._set_sensitivity_elements(sensitivity_elements=sensitivity_elements, i_rep=0, i_treat=0)

# test variances
sensitivity_elements = dict({'sigma2': 1.0, 'nu2': -2.4, 'psi_sigma2': 1.0, 'psi_nu2': 1.0})
sensitivity_elements = dict({'sigma2': 1.0, 'nu2': -2.4, 'psi_sigma2': 1.0, 'psi_nu2': 1.0, 'riesz_rep': 1.0})
_ = dml_irm._set_sensitivity_elements(sensitivity_elements=sensitivity_elements, i_rep=0, i_treat=0)
msg = ('sensitivity_elements sigma2 and nu2 have to be positive. '
r'Got sigma2 \[\[\[1.\]\]\] and nu2 \[\[\[-2.4\]\]\]. '
Expand Down Expand Up @@ -1176,8 +1133,7 @@ def test_doubleml_sensitivity_plot_input():
dml_irm = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(), trimming_threshold=0.1)
dml_irm.fit()

msg = (r'Apply sensitivity_analysis\(\) to include senario in sensitivity_plot. '
'The values of rho and the level are used for the scenario.')
msg = (r'Apply sensitivity_analysis\(\) to include senario in sensitivity_plot. ')
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_plot()

Expand Down
26 changes: 16 additions & 10 deletions doubleml/tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def dml_framework_from_doubleml_fixture(n_rep):
ml_g = LinearRegression()
ml_m = LogisticRegression()

dml_irm_obj = DoubleMLIRM(dml_data, ml_g, ml_m)
dml_irm_obj = DoubleMLIRM(dml_data, ml_g, ml_m, n_rep=n_rep)
dml_irm_obj.fit()
dml_framework_obj = dml_irm_obj.construct_framework()

Expand All @@ -179,7 +179,7 @@ def dml_framework_from_doubleml_fixture(n_rep):

# subtract objects
dml_data_2 = make_irm_data()
dml_irm_obj_2 = DoubleMLIRM(dml_data_2, ml_g, ml_m)
dml_irm_obj_2 = DoubleMLIRM(dml_data_2, ml_g, ml_m, n_rep=n_rep)
dml_irm_obj_2.fit()
dml_framework_obj_2 = dml_irm_obj_2.construct_framework()

Expand Down Expand Up @@ -218,6 +218,7 @@ def dml_framework_from_doubleml_fixture(n_rep):
'ci_joint_sub_obj': ci_joint_sub_obj,
'ci_joint_mul_obj': ci_joint_mul_obj,
'ci_joint_concat': ci_joint_concat,
'n_rep': n_rep,
}
return result_dict

Expand Down Expand Up @@ -257,14 +258,19 @@ def test_dml_framework_from_doubleml_se(dml_framework_from_doubleml_fixture):
dml_framework_from_doubleml_fixture['dml_framework_obj_add_obj'].all_ses,
2*dml_framework_from_doubleml_fixture['dml_obj'].all_se
)
scaling = np.array([dml_framework_from_doubleml_fixture['dml_obj']._var_scaling_factors]).reshape(-1, 1)
sub_var = np.mean(
np.square(dml_framework_from_doubleml_fixture['dml_obj'].psi - dml_framework_from_doubleml_fixture['dml_obj_2'].psi),
axis=0)
assert np.allclose(
dml_framework_from_doubleml_fixture['dml_framework_obj_sub_obj'].all_ses,
np.sqrt(sub_var / scaling)
)

if dml_framework_from_doubleml_fixture['n_rep'] == 1:
# formula only valid for n_rep = 1
scaling = np.array([dml_framework_from_doubleml_fixture['dml_obj']._var_scaling_factors]).reshape(-1, 1)
sub_var = np.mean(
np.square(dml_framework_from_doubleml_fixture['dml_obj'].psi
- dml_framework_from_doubleml_fixture['dml_obj_2'].psi),
axis=0)
assert np.allclose(
dml_framework_from_doubleml_fixture['dml_framework_obj_sub_obj'].all_ses,
np.sqrt(sub_var / scaling)
)

assert np.allclose(
dml_framework_from_doubleml_fixture['dml_framework_obj_mul_obj'].all_ses,
2*dml_framework_from_doubleml_fixture['dml_obj'].all_se
Expand Down
Loading