@@ -31,9 +31,13 @@ class ROC_AUC(EpochMetric):
3131
3232 Args:
3333 output_transform (callable, optional): a callable that is used to transform the
34- :class:`~ignite.engine.Engine`'s `process_function`'s output into the
34+ :class:`~ignite.engine.Engine`'s `` process_function` `'s output into the
3535 form expected by the metric. This can be useful if, for example, you have a multi-output model and
3636 you want to compute the metric with respect to one of the outputs.
37+ check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.roc_curve
38+ <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#
39+ sklearn.metrics.roc_auc_score>`_ is run on the first batch of data to ensure there are
40+ no issues. User will be warned in case there are any issues computing the function.
3741
3842 ROC_AUC expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
3943 values. To apply an activation to y_pred, use output_transform as shown below:
@@ -49,8 +53,10 @@ def activated_output_transform(output):
4953
5054 """
5155
52- def __init__ (self , output_transform = lambda x : x ):
53- super (ROC_AUC , self ).__init__ (roc_auc_compute_fn , output_transform = output_transform )
56+ def __init__ (self , output_transform = lambda x : x , check_compute_fn : bool = False ):
57+ super (ROC_AUC , self ).__init__ (
58+ roc_auc_compute_fn , output_transform = output_transform , check_compute_fn = check_compute_fn
59+ )
5460
5561
5662class RocCurve (EpochMetric ):
@@ -61,9 +67,13 @@ class RocCurve(EpochMetric):
6167
6268 Args:
6369 output_transform (callable, optional): a callable that is used to transform the
64- :class:`~ignite.engine.Engine`'s `process_function`'s output into the
70+ :class:`~ignite.engine.Engine`'s `` process_function` `'s output into the
6571 form expected by the metric. This can be useful if, for example, you have a multi-output model and
6672 you want to compute the metric with respect to one of the outputs.
73+ check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.roc_curve
74+ <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#
75+ sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are
76+ no issues. User will be warned in case there are any issues computing the function.
6777
6878 RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
6979 values. To apply an activation to y_pred, use output_transform as shown below:
@@ -79,5 +89,7 @@ def activated_output_transform(output):
7989
8090 """
8191
82- def __init__ (self , output_transform = lambda x : x ):
83- super (RocCurve , self ).__init__ (roc_auc_curve_compute_fn , output_transform = output_transform )
92+ def __init__ (self , output_transform = lambda x : x , check_compute_fn : bool = False ):
93+ super (RocCurve , self ).__init__ (
94+ roc_auc_curve_compute_fn , output_transform = output_transform , check_compute_fn = check_compute_fn
95+ )
0 commit comments