@@ -50,11 +50,12 @@ def __init__(self,
5050 # initialize learners and parameters which are set model specific
5151 self ._learner = None
5252 self ._params = None
53+ self ._is_classifier = {}
5354
5455 # initialize predictions and target to None which are only stored if method fit is called with store_predictions=True
5556 self ._predictions = None
5657 self ._nuisance_targets = None
57- self ._rmses = None
58+ self ._nuisance_loss = None
5859
5960 # initialize models to None which are only stored if method fit is called with store_models=True
6061 self ._models = None
@@ -119,10 +120,18 @@ def __str__(self):
119120 learner_info = ''
120121 for key , value in self .learner .items ():
121122 learner_info += f'Learner { key } : { str (value )} \n '
122- if self .rmses is not None :
123+ if self .nuisance_loss is not None :
123124 learner_info += 'Out-of-sample Performance:\n '
124- for learner in self .params_names :
125- learner_info += f'Learner { learner } RMSE: { self .rmses [learner ]} \n '
125+ is_classifier = [value for value in self ._is_classifier .values ()]
126+ is_regressor = [not value for value in is_classifier ]
127+ if any (is_regressor ):
128+ learner_info += 'Regression:\n '
129+ for learner in [key for key , value in self ._is_classifier .items () if value is False ]:
130+ learner_info += f'Learner { learner } RMSE: { self .nuisance_loss [learner ]} \n '
131+ if any (is_classifier ):
132+ learner_info += 'Classification:\n '
133+ for learner in [key for key , value in self ._is_classifier .items () if value is True ]:
134+ learner_info += f'Learner { learner } Log Loss: { self .nuisance_loss [learner ]} \n '
126135
127136 if self ._is_cluster_data :
128137 resampling_info = f'No. folds per cluster: { self ._n_folds_per_cluster } \n ' \
@@ -234,11 +243,11 @@ def nuisance_targets(self):
234243 return self ._nuisance_targets
235244
236245 @property
237- def rmses (self ):
246+ def nuisance_loss (self ):
238247 """
239- The root-mean-squared-errors of the nuisance models .
248+ The losses of the nuisance models ( root-mean-squared-errors or logloss) .
240249 """
241- return self ._rmses
250+ return self ._nuisance_loss
242251
243252 @property
244253 def models (self ):
@@ -915,8 +924,8 @@ def _check_fit(self, n_jobs_cv, store_predictions, external_predictions, store_m
915924 raise NotImplementedError (f"External predictions not implemented for { self .__class__ .__name__ } ." )
916925
917926 def _initalize_fit (self , store_predictions , store_models ):
918- # initialize rmse arrays for nuisance functions evaluation
919- self ._initialize_rmses ()
927+ # initialize loss arrays for nuisance functions evaluation
928+ self ._initialize_nuisance_loss ()
920929
921930 if store_predictions :
922931 self ._initialize_predictions_and_targets ()
@@ -942,8 +951,8 @@ def _fit_nuisance_and_score_elements(self, n_jobs_cv, store_predictions, externa
942951
943952 self ._set_score_elements (score_elements , self ._i_rep , self ._i_treat )
944953
945- # calculate rmses and store predictions and targets of the nuisance models
946- self ._calc_rmses (preds ['predictions' ], preds ['targets' ])
954+ # calculate nuisance losses and store predictions and targets of the nuisance models
955+ self ._calc_nuisance_loss (preds ['predictions' ], preds ['targets' ])
947956 if store_predictions :
948957 self ._store_predictions_and_targets (preds ['predictions' ], preds ['targets' ])
949958 if store_models :
@@ -1001,9 +1010,11 @@ def _initialize_predictions_and_targets(self):
10011010 self ._nuisance_targets = {learner : np .full ((self ._dml_data .n_obs , self .n_rep , self ._dml_data .n_coefs ), np .nan )
10021011 for learner in self .params_names }
10031012
1004- def _initialize_rmses (self ):
1005- self ._rmses = {learner : np .full ((self .n_rep , self ._dml_data .n_coefs ), np .nan )
1006- for learner in self .params_names }
1013+ def _initialize_nuisance_loss (self ):
1014+ self ._nuisance_loss = {
1015+ learner : np .full ((self .n_rep , self ._dml_data .n_coefs ), np .nan )
1016+ for learner in self .params_names
1017+ }
10071018
10081019 def _initialize_models (self ):
10091020 self ._models = {learner : {treat_var : [None ] * self .n_rep for treat_var in self ._dml_data .d_cols }
@@ -1014,13 +1025,33 @@ def _store_predictions_and_targets(self, preds, targets):
10141025 self ._predictions [learner ][:, self ._i_rep , self ._i_treat ] = preds [learner ]
10151026 self ._nuisance_targets [learner ][:, self ._i_rep , self ._i_treat ] = targets [learner ]
10161027
1017- def _calc_rmses (self , preds , targets ):
1028+ def _calc_nuisance_loss (self , preds , targets ):
1029+ self ._is_classifier = {key : False for key in self .params_names }
10181030 for learner in self .params_names :
1031+ # check if the learner is a classifier
1032+ learner_keys = [key for key in self ._learner .keys () if key in learner ]
1033+ assert len (learner_keys ) == 1
1034+ self ._is_classifier [learner ] = self ._check_learner (
1035+ self ._learner [learner_keys [0 ]],
1036+ learner ,
1037+ regressor = True , classifier = True
1038+ )
1039+
10191040 if targets [learner ] is None :
1020- self ._rmses [learner ][self ._i_rep , self ._i_treat ] = np .nan
1041+ self ._nuisance_loss [learner ][self ._i_rep , self ._i_treat ] = np .nan
10211042 else :
1022- sq_error = np .power (targets [learner ] - preds [learner ], 2 )
1023- self ._rmses [learner ][self ._i_rep , self ._i_treat ] = np .sqrt (np .nanmean (sq_error , axis = 0 ))
1043+ learner_keys = [key for key in self ._learner .keys () if key in learner ]
1044+ assert len (learner_keys ) == 1
1045+
1046+ if self ._is_classifier [learner ]:
1047+ predictions = np .clip (preds [learner ], 1e-15 , 1 - 1e-15 )
1048+ logloss = targets [learner ] * np .log (predictions ) + (1 - targets [learner ]) * np .log (1 - predictions )
1049+ loss = - np .nanmean (logloss , axis = 0 )
1050+ else :
1051+ sq_error = np .power (targets [learner ] - preds [learner ], 2 )
1052+ loss = np .sqrt (np .nanmean (sq_error , axis = 0 ))
1053+
1054+ self ._nuisance_loss [learner ][self ._i_rep , self ._i_treat ] = loss
10241055
10251056 def _store_models (self , models ):
10261057 for learner in self .params_names :
0 commit comments