Skip to content

Commit 870c12c

Browse files
committed
remove ml_m from did_cs experimental
1 parent 36cf51b commit 870c12c

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

doubleml/did/did_cs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,13 @@ def trimming_threshold(self):
162162
return self._trimming_threshold
163163

164164
def _initialize_ml_nuisance_params(self):
165-
valid_learner = ['ml_g_d0_t0', 'ml_g_d0_t1',
166-
'ml_g_d1_t0', 'ml_g_d1_t1', 'ml_m']
165+
if self.score == 'observational':
166+
valid_learner = ['ml_g_d0_t0', 'ml_g_d0_t1',
167+
'ml_g_d1_t0', 'ml_g_d1_t1', 'ml_m']
168+
else:
169+
assert self.score == 'experimental'
170+
valid_learner = ['ml_g_d0_t0', 'ml_g_d0_t1',
171+
'ml_g_d1_t0', 'ml_g_d1_t1']
167172
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols}
168173
for learner in valid_learner}
169174

doubleml/did/tests/_utils_did_cs_manual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ def fit_sensitivity_elements_did_cs(y, d, t, all_coef, predictions, score, in_sa
285285

286286
for i_rep in range(n_rep):
287287

288-
m_hat = predictions['ml_m'][:, i_rep, 0]
289288
g_hat_d0_t0 = predictions['ml_g_d0_t0'][:, i_rep, 0]
290289
g_hat_d0_t1 = predictions['ml_g_d0_t1'][:, i_rep, 0]
291290
g_hat_d1_t0 = predictions['ml_g_d1_t0'][:, i_rep, 0]
@@ -305,6 +304,7 @@ def fit_sensitivity_elements_did_cs(y, d, t, all_coef, predictions, score, in_sa
305304
p_hat = np.mean(d)
306305
lambda_hat = np.mean(t)
307306
if score == 'observational':
307+
m_hat = predictions['ml_m'][:, i_rep, 0]
308308
propensity_weight_d0 = np.divide(m_hat, 1.0-m_hat)
309309
if in_sample_normalization:
310310
weight_d0t1 = np.multiply(d0t1, propensity_weight_d0)

doubleml/did/tests/test_did_cs_external_predictions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def doubleml_didcs_fixture(did_score, n_rep):
3939
ext_predictions["d"]["ml_g_d0_t1"] = dml_did_cs.predictions["ml_g_d0_t1"][:, :, 0]
4040
ext_predictions["d"]["ml_g_d1_t0"] = dml_did_cs.predictions["ml_g_d1_t0"][:, :, 0]
4141
ext_predictions["d"]["ml_g_d1_t1"] = dml_did_cs.predictions["ml_g_d1_t1"][:, :, 0]
42-
ext_predictions["d"]["ml_m"] = dml_did_cs.predictions["ml_m"][:, :, 0]
42+
if did_score == "observational":
43+
ext_predictions["d"]["ml_m"] = dml_did_cs.predictions["ml_m"][:, :, 0]
4344

4445
dml_did_cs_ext = DoubleMLDIDCS(ml_g=DMLDummyRegressor(), ml_m=DMLDummyClassifier(), **kwargs)
4546
dml_did_cs_ext.set_sample_splitting(all_smpls)

0 commit comments

Comments
 (0)