Skip to content

Commit 296114d

Browse files
authored
Merge pull request #254 from DoubleML/s-add-logloss
Add Log Loss to nuisance evaluation
2 parents 348f0ca + 870c12c commit 296114d

File tree

10 files changed

+133
-82
lines changed

10 files changed

+133
-82
lines changed

doubleml/did/did.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,11 @@ def trimming_threshold(self):
161161
return self._trimming_threshold
162162

163163
def _initialize_ml_nuisance_params(self):
164-
valid_learner = ['ml_g0', 'ml_g1', 'ml_m']
164+
if self.score == 'observational':
165+
valid_learner = ['ml_g0', 'ml_g1', 'ml_m']
166+
else:
167+
assert self.score == 'experimental'
168+
valid_learner = ['ml_g0', 'ml_g1']
165169
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols}
166170
for learner in valid_learner}
167171

doubleml/did/did_cs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,13 @@ def trimming_threshold(self):
162162
return self._trimming_threshold
163163

164164
def _initialize_ml_nuisance_params(self):
165-
valid_learner = ['ml_g_d0_t0', 'ml_g_d0_t1',
166-
'ml_g_d1_t0', 'ml_g_d1_t1', 'ml_m']
165+
if self.score == 'observational':
166+
valid_learner = ['ml_g_d0_t0', 'ml_g_d0_t1',
167+
'ml_g_d1_t0', 'ml_g_d1_t1', 'ml_m']
168+
else:
169+
assert self.score == 'experimental'
170+
valid_learner = ['ml_g_d0_t0', 'ml_g_d0_t1',
171+
'ml_g_d1_t0', 'ml_g_d1_t1']
167172
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols}
168173
for learner in valid_learner}
169174

doubleml/did/tests/_utils_did_cs_manual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ def fit_sensitivity_elements_did_cs(y, d, t, all_coef, predictions, score, in_sa
285285

286286
for i_rep in range(n_rep):
287287

288-
m_hat = predictions['ml_m'][:, i_rep, 0]
289288
g_hat_d0_t0 = predictions['ml_g_d0_t0'][:, i_rep, 0]
290289
g_hat_d0_t1 = predictions['ml_g_d0_t1'][:, i_rep, 0]
291290
g_hat_d1_t0 = predictions['ml_g_d1_t0'][:, i_rep, 0]
@@ -305,6 +304,7 @@ def fit_sensitivity_elements_did_cs(y, d, t, all_coef, predictions, score, in_sa
305304
p_hat = np.mean(d)
306305
lambda_hat = np.mean(t)
307306
if score == 'observational':
307+
m_hat = predictions['ml_m'][:, i_rep, 0]
308308
propensity_weight_d0 = np.divide(m_hat, 1.0-m_hat)
309309
if in_sample_normalization:
310310
weight_d0t1 = np.multiply(d0t1, propensity_weight_d0)

doubleml/did/tests/_utils_did_manual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def fit_sensitivity_elements_did(y, d, all_coef, predictions, score, in_sample_n
220220

221221
for i_rep in range(n_rep):
222222

223-
m_hat = predictions['ml_m'][:, i_rep, 0]
224223
g_hat0 = predictions['ml_g0'][:, i_rep, 0]
225224
g_hat1 = predictions['ml_g1'][:, i_rep, 0]
226225

@@ -229,6 +228,7 @@ def fit_sensitivity_elements_did(y, d, all_coef, predictions, score, in_sample_n
229228
psi_sigma2[:, i_rep, 0] = sigma2_score_element - sigma2[0, i_rep, 0]
230229

231230
if score == 'observational':
231+
m_hat = predictions['ml_m'][:, i_rep, 0]
232232
propensity_weight_d0 = np.divide(m_hat, 1.0-m_hat)
233233
if in_sample_normalization:
234234
m_alpha_1 = np.divide(d, np.mean(d))

doubleml/did/tests/test_did_cs_external_predictions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def doubleml_didcs_fixture(did_score, n_rep):
3939
ext_predictions["d"]["ml_g_d0_t1"] = dml_did_cs.predictions["ml_g_d0_t1"][:, :, 0]
4040
ext_predictions["d"]["ml_g_d1_t0"] = dml_did_cs.predictions["ml_g_d1_t0"][:, :, 0]
4141
ext_predictions["d"]["ml_g_d1_t1"] = dml_did_cs.predictions["ml_g_d1_t1"][:, :, 0]
42-
ext_predictions["d"]["ml_m"] = dml_did_cs.predictions["ml_m"][:, :, 0]
42+
if did_score == "observational":
43+
ext_predictions["d"]["ml_m"] = dml_did_cs.predictions["ml_m"][:, :, 0]
4344

4445
dml_did_cs_ext = DoubleMLDIDCS(ml_g=DMLDummyRegressor(), ml_m=DMLDummyClassifier(), **kwargs)
4546
dml_did_cs_ext.set_sample_splitting(all_smpls)

doubleml/did/tests/test_did_external_predictions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def doubleml_did_fixture(did_score, n_rep):
3636

3737
ext_predictions["d"]["ml_g0"] = dml_did.predictions["ml_g0"][:, :, 0]
3838
ext_predictions["d"]["ml_g1"] = dml_did.predictions["ml_g1"][:, :, 0]
39-
ext_predictions["d"]["ml_m"] = dml_did.predictions["ml_m"][:, :, 0]
39+
if did_score == "observational":
40+
ext_predictions["d"]["ml_m"] = dml_did.predictions["ml_m"][:, :, 0]
4041

4142
dml_did_ext = DoubleMLDID(ml_g=DMLDummyRegressor(), ml_m=DMLDummyClassifier(), **kwargs)
4243
dml_did_ext.set_sample_splitting(all_smpls)

doubleml/double_ml.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ def __init__(self,
5050
# initialize learners and parameters which are set model specific
5151
self._learner = None
5252
self._params = None
53+
self._is_classifier = {}
5354

5455
# initialize predictions and target to None which are only stored if method fit is called with store_predictions=True
5556
self._predictions = None
5657
self._nuisance_targets = None
57-
self._rmses = None
58+
self._nuisance_loss = None
5859

5960
# initialize models to None which are only stored if method fit is called with store_models=True
6061
self._models = None
@@ -119,10 +120,18 @@ def __str__(self):
119120
learner_info = ''
120121
for key, value in self.learner.items():
121122
learner_info += f'Learner {key}: {str(value)}\n'
122-
if self.rmses is not None:
123+
if self.nuisance_loss is not None:
123124
learner_info += 'Out-of-sample Performance:\n'
124-
for learner in self.params_names:
125-
learner_info += f'Learner {learner} RMSE: {self.rmses[learner]}\n'
125+
is_classifier = [value for value in self._is_classifier.values()]
126+
is_regressor = [not value for value in is_classifier]
127+
if any(is_regressor):
128+
learner_info += 'Regression:\n'
129+
for learner in [key for key, value in self._is_classifier.items() if value is False]:
130+
learner_info += f'Learner {learner} RMSE: {self.nuisance_loss[learner]}\n'
131+
if any(is_classifier):
132+
learner_info += 'Classification:\n'
133+
for learner in [key for key, value in self._is_classifier.items() if value is True]:
134+
learner_info += f'Learner {learner} Log Loss: {self.nuisance_loss[learner]}\n'
126135

127136
if self._is_cluster_data:
128137
resampling_info = f'No. folds per cluster: {self._n_folds_per_cluster}\n' \
@@ -234,11 +243,11 @@ def nuisance_targets(self):
234243
return self._nuisance_targets
235244

236245
@property
237-
def rmses(self):
246+
def nuisance_loss(self):
238247
"""
239-
The root-mean-squared-errors of the nuisance models.
248+
The losses of the nuisance models (root-mean-squared-errors or logloss).
240249
"""
241-
return self._rmses
250+
return self._nuisance_loss
242251

243252
@property
244253
def models(self):
@@ -915,8 +924,8 @@ def _check_fit(self, n_jobs_cv, store_predictions, external_predictions, store_m
915924
raise NotImplementedError(f"External predictions not implemented for {self.__class__.__name__}.")
916925

917926
def _initalize_fit(self, store_predictions, store_models):
918-
# initialize rmse arrays for nuisance functions evaluation
919-
self._initialize_rmses()
927+
# initialize loss arrays for nuisance functions evaluation
928+
self._initialize_nuisance_loss()
920929

921930
if store_predictions:
922931
self._initialize_predictions_and_targets()
@@ -942,8 +951,8 @@ def _fit_nuisance_and_score_elements(self, n_jobs_cv, store_predictions, externa
942951

943952
self._set_score_elements(score_elements, self._i_rep, self._i_treat)
944953

945-
# calculate rmses and store predictions and targets of the nuisance models
946-
self._calc_rmses(preds['predictions'], preds['targets'])
954+
# calculate nuisance losses and store predictions and targets of the nuisance models
955+
self._calc_nuisance_loss(preds['predictions'], preds['targets'])
947956
if store_predictions:
948957
self._store_predictions_and_targets(preds['predictions'], preds['targets'])
949958
if store_models:
@@ -1001,9 +1010,11 @@ def _initialize_predictions_and_targets(self):
10011010
self._nuisance_targets = {learner: np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs), np.nan)
10021011
for learner in self.params_names}
10031012

1004-
def _initialize_rmses(self):
1005-
self._rmses = {learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan)
1006-
for learner in self.params_names}
1013+
def _initialize_nuisance_loss(self):
1014+
self._nuisance_loss = {
1015+
learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan)
1016+
for learner in self.params_names
1017+
}
10071018

10081019
def _initialize_models(self):
10091020
self._models = {learner: {treat_var: [None] * self.n_rep for treat_var in self._dml_data.d_cols}
@@ -1014,13 +1025,33 @@ def _store_predictions_and_targets(self, preds, targets):
10141025
self._predictions[learner][:, self._i_rep, self._i_treat] = preds[learner]
10151026
self._nuisance_targets[learner][:, self._i_rep, self._i_treat] = targets[learner]
10161027

1017-
def _calc_rmses(self, preds, targets):
1028+
def _calc_nuisance_loss(self, preds, targets):
1029+
self._is_classifier = {key: False for key in self.params_names}
10181030
for learner in self.params_names:
1031+
# check if the learner is a classifier
1032+
learner_keys = [key for key in self._learner.keys() if key in learner]
1033+
assert len(learner_keys) == 1
1034+
self._is_classifier[learner] = self._check_learner(
1035+
self._learner[learner_keys[0]],
1036+
learner,
1037+
regressor=True, classifier=True
1038+
)
1039+
10191040
if targets[learner] is None:
1020-
self._rmses[learner][self._i_rep, self._i_treat] = np.nan
1041+
self._nuisance_loss[learner][self._i_rep, self._i_treat] = np.nan
10211042
else:
1022-
sq_error = np.power(targets[learner] - preds[learner], 2)
1023-
self._rmses[learner][self._i_rep, self._i_treat] = np.sqrt(np.nanmean(sq_error, axis=0))
1043+
learner_keys = [key for key in self._learner.keys() if key in learner]
1044+
assert len(learner_keys) == 1
1045+
1046+
if self._is_classifier[learner]:
1047+
predictions = np.clip(preds[learner], 1e-15, 1 - 1e-15)
1048+
logloss = targets[learner] * np.log(predictions) + (1 - targets[learner]) * np.log(1 - predictions)
1049+
loss = -np.nanmean(logloss, axis=0)
1050+
else:
1051+
sq_error = np.power(targets[learner] - preds[learner], 2)
1052+
loss = np.sqrt(np.nanmean(sq_error, axis=0))
1053+
1054+
self._nuisance_loss[learner][self._i_rep, self._i_treat] = loss
10241055

10251056
def _store_models(self, models):
10261057
for learner in self.params_names:

doubleml/tests/test_evaluate_learner.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from sklearn.linear_model import LogisticRegression, LinearRegression
88
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
99

10+
from doubleml.utils._estimation import _logloss
11+
1012

1113
np.random.seed(3141)
1214
data = make_irm_data(theta=0.5, n_obs=200, dim_x=5, return_type='DataFrame')
@@ -47,26 +49,27 @@ def dml_irm_eval_learner_fixture(learner, trimming_threshold, n_rep):
4749
n_rep=n_rep,
4850
trimming_threshold=trimming_threshold)
4951
dml_irm_obj.fit()
50-
res_manual = dml_irm_obj.evaluate_learners()
52+
res_manual = dml_irm_obj.evaluate_learners(learners=['ml_g0', 'ml_g1'])
53+
res_manual['ml_m'] = dml_irm_obj.evaluate_learners(learners=['ml_m'], metric=_logloss)['ml_m']
5154

52-
res_dict = {'rmses': dml_irm_obj.rmses,
53-
'rmses_manual': res_manual
55+
res_dict = {'nuisance_loss': dml_irm_obj.nuisance_loss,
56+
'nuisance_loss_manual': res_manual
5457
}
5558
return res_dict
5659

5760

5861
@pytest.mark.ci
5962
def test_dml_irm_eval_learner(dml_irm_eval_learner_fixture, n_rep):
60-
assert dml_irm_eval_learner_fixture['rmses_manual']['ml_g0'].shape == (n_rep, 1)
61-
assert dml_irm_eval_learner_fixture['rmses_manual']['ml_g1'].shape == (n_rep, 1)
62-
assert dml_irm_eval_learner_fixture['rmses_manual']['ml_m'].shape == (n_rep, 1)
63+
assert dml_irm_eval_learner_fixture['nuisance_loss_manual']['ml_g0'].shape == (n_rep, 1)
64+
assert dml_irm_eval_learner_fixture['nuisance_loss_manual']['ml_g1'].shape == (n_rep, 1)
65+
assert dml_irm_eval_learner_fixture['nuisance_loss_manual']['ml_m'].shape == (n_rep, 1)
6366

64-
assert np.allclose(dml_irm_eval_learner_fixture['rmses_manual']['ml_g0'],
65-
dml_irm_eval_learner_fixture['rmses']['ml_g0'],
67+
assert np.allclose(dml_irm_eval_learner_fixture['nuisance_loss_manual']['ml_g0'],
68+
dml_irm_eval_learner_fixture['nuisance_loss']['ml_g0'],
6669
rtol=1e-9, atol=1e-4)
67-
assert np.allclose(dml_irm_eval_learner_fixture['rmses_manual']['ml_g1'],
68-
dml_irm_eval_learner_fixture['rmses']['ml_g1'],
70+
assert np.allclose(dml_irm_eval_learner_fixture['nuisance_loss_manual']['ml_g1'],
71+
dml_irm_eval_learner_fixture['nuisance_loss']['ml_g1'],
6972
rtol=1e-9, atol=1e-4)
70-
assert np.allclose(dml_irm_eval_learner_fixture['rmses_manual']['ml_m'],
71-
dml_irm_eval_learner_fixture['rmses']['ml_m'],
73+
assert np.allclose(dml_irm_eval_learner_fixture['nuisance_loss_manual']['ml_m'],
74+
dml_irm_eval_learner_fixture['nuisance_loss']['ml_m'],
7275
rtol=1e-9, atol=1e-4)

doubleml/tests/test_return_types.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -349,50 +349,50 @@ def test_stored_nuisance_targets():
349349

350350

351351
@pytest.mark.ci
352-
def test_rmses():
353-
assert plr_obj.rmses['ml_l'].shape == (n_rep, n_treat)
354-
assert plr_obj.rmses['ml_m'].shape == (n_rep, n_treat)
355-
356-
assert pliv_obj.rmses['ml_l'].shape == (n_rep, n_treat)
357-
assert pliv_obj.rmses['ml_m'].shape == (n_rep, n_treat)
358-
assert pliv_obj.rmses['ml_r'].shape == (n_rep, n_treat)
359-
360-
assert irm_obj.rmses['ml_g0'].shape == (n_rep, n_treat)
361-
assert irm_obj.rmses['ml_g1'].shape == (n_rep, n_treat)
362-
assert irm_obj.rmses['ml_m'].shape == (n_rep, n_treat)
363-
364-
assert iivm_obj.rmses['ml_g0'].shape == (n_rep, n_treat)
365-
assert iivm_obj.rmses['ml_g1'].shape == (n_rep, n_treat)
366-
assert iivm_obj.rmses['ml_m'].shape == (n_rep, n_treat)
367-
assert iivm_obj.rmses['ml_r0'].shape == (n_rep, n_treat)
368-
assert iivm_obj.rmses['ml_r1'].shape == (n_rep, n_treat)
369-
370-
assert cvar_obj.rmses['ml_g'].shape == (n_rep, n_treat)
371-
assert cvar_obj.rmses['ml_m'].shape == (n_rep, n_treat)
372-
373-
assert pq_obj.rmses['ml_g'].shape == (n_rep, n_treat)
374-
assert pq_obj.rmses['ml_m'].shape == (n_rep, n_treat)
375-
376-
assert lpq_obj.rmses['ml_g_du_z0'].shape == (n_rep, n_treat)
377-
assert lpq_obj.rmses['ml_g_du_z1'].shape == (n_rep, n_treat)
378-
assert lpq_obj.rmses['ml_m_z'].shape == (n_rep, n_treat)
379-
assert lpq_obj.rmses['ml_m_d_z0'].shape == (n_rep, n_treat)
380-
assert lpq_obj.rmses['ml_m_d_z1'].shape == (n_rep, n_treat)
381-
382-
assert did_obj.rmses['ml_g0'].shape == (n_rep, n_treat)
383-
assert did_obj.rmses['ml_g1'].shape == (n_rep, n_treat)
384-
assert did_obj.rmses['ml_m'].shape == (n_rep, n_treat)
385-
386-
assert did_cs_obj.rmses['ml_g_d0_t0'].shape == (n_rep, n_treat)
387-
assert did_cs_obj.rmses['ml_g_d0_t1'].shape == (n_rep, n_treat)
388-
assert did_cs_obj.rmses['ml_g_d1_t0'].shape == (n_rep, n_treat)
389-
assert did_cs_obj.rmses['ml_g_d1_t1'].shape == (n_rep, n_treat)
390-
assert did_cs_obj.rmses['ml_m'].shape == (n_rep, n_treat)
391-
392-
assert ssm_obj.rmses['ml_g_d0'].shape == (n_rep, n_treat)
393-
assert ssm_obj.rmses['ml_g_d1'].shape == (n_rep, n_treat)
394-
assert ssm_obj.rmses['ml_m'].shape == (n_rep, n_treat)
395-
assert ssm_obj.rmses['ml_pi'].shape == (n_rep, n_treat)
352+
def test_nuisance_loss():
353+
assert plr_obj.nuisance_loss['ml_l'].shape == (n_rep, n_treat)
354+
assert plr_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
355+
356+
assert pliv_obj.nuisance_loss['ml_l'].shape == (n_rep, n_treat)
357+
assert pliv_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
358+
assert pliv_obj.nuisance_loss['ml_r'].shape == (n_rep, n_treat)
359+
360+
assert irm_obj.nuisance_loss['ml_g0'].shape == (n_rep, n_treat)
361+
assert irm_obj.nuisance_loss['ml_g1'].shape == (n_rep, n_treat)
362+
assert irm_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
363+
364+
assert iivm_obj.nuisance_loss['ml_g0'].shape == (n_rep, n_treat)
365+
assert iivm_obj.nuisance_loss['ml_g1'].shape == (n_rep, n_treat)
366+
assert iivm_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
367+
assert iivm_obj.nuisance_loss['ml_r0'].shape == (n_rep, n_treat)
368+
assert iivm_obj.nuisance_loss['ml_r1'].shape == (n_rep, n_treat)
369+
370+
assert cvar_obj.nuisance_loss['ml_g'].shape == (n_rep, n_treat)
371+
assert cvar_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
372+
373+
assert pq_obj.nuisance_loss['ml_g'].shape == (n_rep, n_treat)
374+
assert pq_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
375+
376+
assert lpq_obj.nuisance_loss['ml_g_du_z0'].shape == (n_rep, n_treat)
377+
assert lpq_obj.nuisance_loss['ml_g_du_z1'].shape == (n_rep, n_treat)
378+
assert lpq_obj.nuisance_loss['ml_m_z'].shape == (n_rep, n_treat)
379+
assert lpq_obj.nuisance_loss['ml_m_d_z0'].shape == (n_rep, n_treat)
380+
assert lpq_obj.nuisance_loss['ml_m_d_z1'].shape == (n_rep, n_treat)
381+
382+
assert did_obj.nuisance_loss['ml_g0'].shape == (n_rep, n_treat)
383+
assert did_obj.nuisance_loss['ml_g1'].shape == (n_rep, n_treat)
384+
assert did_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
385+
386+
assert did_cs_obj.nuisance_loss['ml_g_d0_t0'].shape == (n_rep, n_treat)
387+
assert did_cs_obj.nuisance_loss['ml_g_d0_t1'].shape == (n_rep, n_treat)
388+
assert did_cs_obj.nuisance_loss['ml_g_d1_t0'].shape == (n_rep, n_treat)
389+
assert did_cs_obj.nuisance_loss['ml_g_d1_t1'].shape == (n_rep, n_treat)
390+
assert did_cs_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
391+
392+
assert ssm_obj.nuisance_loss['ml_g_d0'].shape == (n_rep, n_treat)
393+
assert ssm_obj.nuisance_loss['ml_g_d1'].shape == (n_rep, n_treat)
394+
assert ssm_obj.nuisance_loss['ml_m'].shape == (n_rep, n_treat)
395+
assert ssm_obj.nuisance_loss['ml_pi'].shape == (n_rep, n_treat)
396396

397397

398398
@pytest.mark.ci

doubleml/utils/_estimation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn.base import clone
77
from sklearn.preprocessing import LabelEncoder
88
from sklearn.model_selection import KFold, GridSearchCV, RandomizedSearchCV
9-
from sklearn.metrics import root_mean_squared_error
9+
from sklearn.metrics import root_mean_squared_error, log_loss
1010

1111
from statsmodels.nonparametric.kde import KDEUnivariate
1212

@@ -204,6 +204,12 @@ def _rmse(y_true, y_pred):
204204
return rmse
205205

206206

207+
def _logloss(y_true, y_pred):
208+
subset = np.logical_not(np.isnan(y_true))
209+
logloss = log_loss(y_true[subset], y_pred[subset])
210+
return logloss
211+
212+
207213
def _predict_zero_one_propensity(learner, X):
208214
pred_proba = learner.predict_proba(X)
209215
if pred_proba.shape[1] == 2:

0 commit comments

Comments (0)