diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index 1be7e1c0b..c0a01b99d 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -162,25 +162,34 @@ # %% from skore import ComparisonReport -comparator = ComparisonReport(reports=[log_reg_report, rf_report]) +comparison_report = ComparisonReport(reports=[log_reg_report, rf_report]) # %% # As for the :class:`~skore.EstimatorReport` and the # :class:`~skore.CrossValidationReport`, we have a helper: # %% -comparator.help() +comparison_report.help() # %% # Let us display the result of our benchmark: # %% -benchmark_metrics = comparator.metrics.report_metrics() +benchmark_metrics = comparison_report.metrics.report_metrics() benchmark_metrics # %% # We have the result of our benchmark. +# %% +# We display the ROC curve for the two estimator reports we want to compare, by +# superimposing them on the same figure: + +# %% +comparison_report.metrics.roc().plot() +plt.tight_layout() + + # %% # Train-test split with skore # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/skore/src/skore/sklearn/__init__.py b/skore/src/skore/sklearn/__init__.py index f1abb357c..97ea84527 100644 --- a/skore/src/skore/sklearn/__init__.py +++ b/skore/src/skore/sklearn/__init__.py @@ -11,11 +11,11 @@ from skore.sklearn.train_test_split.train_test_split import train_test_split __all__ = [ - "train_test_split", + "ComparisonReport", "CrossValidationReport", "EstimatorReport", - "ComparisonReport", - "RocCurveDisplay", "PrecisionRecallCurveDisplay", "PredictionErrorDisplay", + "RocCurveDisplay", + "train_test_split", ] diff --git a/skore/src/skore/sklearn/_comparison/metrics_accessor.py b/skore/src/skore/sklearn/_comparison/metrics_accessor.py index 9a9403e52..46d965aa9 100644 --- a/skore/src/skore/sklearn/_comparison/metrics_accessor.py +++ b/skore/src/skore/sklearn/_comparison/metrics_accessor.py @@ -5,8 +5,14 @@ from sklearn.utils.metaestimators import available_if from skore.externals._pandas_accessors import DirNamesMixin -from skore.sklearn._base import _BaseAccessor +from skore.sklearn._base import _BaseAccessor, _get_cached_response_values +from skore.sklearn._comparison.precision_recall_curve_display import ( + PrecisionRecallCurveDisplay, +) +from skore.sklearn._comparison.prediction_error_display import PredictionErrorDisplay +from skore.sklearn._comparison.roc_curve_display import RocCurveDisplay from skore.utils._accessor import _check_supported_ml_task +from skore.utils._index import flatten_multi_index from skore.utils._progress_bar import progress_decorator @@ -42,9 +48,10 @@ def report_metrics( y=None, scoring=None, scoring_names=None, - pos_label=None, scoring_kwargs=None, + pos_label=None, indicator_favorability=False, + flat_index=False, ): """Report a set of metrics for the estimators. @@ -78,16 +85,20 @@ def report_metrics( Used to overwrite the default scoring names in the report. It should be of the same length as the ``scoring`` parameter. - pos_label : int, float, bool or str, default=None - The positive class. - scoring_kwargs : dict, default=None The keyword arguments to pass to the scoring functions. + pos_label : int, float, bool or str, default=None + The positive class. + indicator_favorability : bool, default=False Whether or not to add an indicator of the favorability of the metric as an extra column in the returned DataFrame. 
+ flat_index : bool, default=False + Whether to flatten the `MultiIndex` columns. Flat index will always be lower + case, do not include spaces and remove the hash symbol to ease indexing. + Returns ------- pd.DataFrame @@ -129,7 +140,7 @@ def report_metrics( Precision 0.96... 0.96... Recall 0.97... 0.97... """ - return self._compute_metric_scores( + results = self._compute_metric_scores( report_metric_name="report_metrics", data_source=data_source, X=X, @@ -140,6 +151,12 @@ def report_metrics( scoring_names=scoring_names, indicator_favorability=indicator_favorability, ) + if flat_index: + if isinstance(results.columns, pd.MultiIndex): + results.columns = flatten_multi_index(results.columns) + if isinstance(results.index, pd.MultiIndex): + results.index = flatten_multi_index(results.index) + return results @progress_decorator(description="Compute metric for each split") def _compute_metric_scores( @@ -1096,3 +1113,328 @@ def __repr__(self): class_name="skore.ComparisonReport.metrics", help_method_name="report.metrics.help()", ) + + @progress_decorator(description="Computing predictions for display") + def _get_display( + self, + *, + X, + y, + data_source, + response_method, + display_class, + display_kwargs, + ): + """Get the display from the cache or compute it. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The data. + + y : array-like of shape (n_samples,) + The target. + + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + response_method : str + The response method. + + display_class : class + The display class. + + display_kwargs : dict + The display kwargs used by `display_class._from_predictions`. + + Returns + ------- + display : display_class + The display. + """ + cache_key = (self._parent._hash, display_class.__name__) + cache_key += tuple(display_kwargs.values()) + cache_key += (data_source,) + + progress = self._progress_info["current_progress"] + main_task = self._progress_info["current_task"] + total_estimators = len(self._parent.estimator_reports_) + progress.update(main_task, total=total_estimators) + + if cache_key in self._parent._cache: + display = self._parent._cache[cache_key] + else: + y_true, y_pred = [], [] + + for report in self._parent.estimator_reports_: + report_X, report_y, _ = report.metrics._get_X_y_and_data_source_hash( + data_source=data_source, + X=X, + y=y, + ) + + y_true.append(report_y) + y_pred.append( + _get_cached_response_values( + cache=report._cache, + estimator_hash=report._hash, + estimator=report._estimator, + X=report_X, + response_method=response_method, + data_source=data_source, + data_source_hash=None, + pos_label=display_kwargs.get("pos_label", None), + ) + ) + progress.update(main_task, advance=1, refresh=True) + + display = display_class._from_predictions( + y_true, + y_pred, + estimators=[r.estimator_ for r in self._parent.estimator_reports_], + estimator_names=self._parent.report_names_, + ml_task=self._parent._ml_task, + data_source=data_source, + **display_kwargs, + ) + self._parent._cache[cache_key] = display + + return display + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None): + """Plot the ROC curve. 
+ + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the report. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the report. + + pos_label : int, float, bool or str, default=None + The positive class. + + Returns + ------- + RocCurveDisplay + The ROC curve display. + + Examples + -------- + >>> from sklearn.datasets import load_breast_cancer + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import train_test_split + >>> from skore import ComparisonReport, EstimatorReport + >>> X, y = load_breast_cancer(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator_1 = LogisticRegression(max_iter=10000, random_state=42) + >>> estimator_report_1 = EstimatorReport( + ... estimator_1, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test, + ... ) + >>> estimator_2 = LogisticRegression(max_iter=10000, random_state=43) + >>> estimator_report_2 = EstimatorReport( + ... estimator_2, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test, + ... ) + >>> comparison_report = ComparisonReport( + ... [estimator_report_1, estimator_report_2] + ... ) + >>> display = comparison_report.metrics.roc() + >>> display.plot() + """ + response_method = ("predict_proba", "decision_function") + display_kwargs = {"pos_label": pos_label} + return self._get_display( + X=X, + y=y, + data_source=data_source, + response_method=response_method, + display_class=RocCurveDisplay, + display_kwargs=display_kwargs, + ) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def precision_recall(self, *, data_source="test", X=None, y=None, pos_label=None): + """Plot the precision-recall curve. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the report. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the report. + + pos_label : int, float, bool or str, default=None + The positive class. + + Returns + ------- + PrecisionRecallCurveDisplay + The precision-recall curve display. 
+ + Examples + -------- + >>> from sklearn.datasets import load_breast_cancer + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import train_test_split + >>> from skore import ComparisonReport, EstimatorReport + >>> X, y = load_breast_cancer(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator_1 = LogisticRegression(max_iter=10000, random_state=42) + >>> estimator_report_1 = EstimatorReport( + ... estimator_1, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test, + ... ) + >>> estimator_2 = LogisticRegression(max_iter=10000, random_state=43) + >>> estimator_report_2 = EstimatorReport( + ... estimator_2, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test, + ... ) + >>> comparison_report = ComparisonReport( + ... [estimator_report_1, estimator_report_2] + ... ) + >>> display = comparison_report.metrics.precision_recall() + >>> display.plot() + """ + response_method = ("predict_proba", "decision_function") + display_kwargs = {"pos_label": pos_label} + return self._get_display( + X=X, + y=y, + data_source=data_source, + response_method=response_method, + display_class=PrecisionRecallCurveDisplay, + display_kwargs=display_kwargs, + ) + + @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) + def prediction_error( + self, + *, + data_source="test", + X=None, + y=None, + subsample=1_000, + random_state=None, + ): + """Plot the prediction error of a regression model. + + Extra keyword arguments will be passed to matplotlib's `plot`. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the report. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the report. + + subsample : float, int or None, default=1_000 + Sampling the samples to be shown on the scatter plot. If `float`, + it should be between 0 and 1 and represents the proportion of the + original dataset. If `int`, it represents the number of samples + display on the scatter plot. If `None`, no subsampling will be + applied. by default, 1,000 samples or less will be displayed. + + random_state : int, default=None + The random state to use for the subsampling. + + Returns + ------- + PredictionErrorDisplay + The prediction error display. + + Examples + -------- + >>> from sklearn.datasets import load_diabetes + >>> from sklearn.linear_model import Ridge + >>> from sklearn.model_selection import train_test_split + >>> from skore import ComparisonReport, EstimatorReport + >>> X, y = load_diabetes(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator_1 = Ridge(random_state=42) + >>> estimator_report_1 = EstimatorReport( + ... estimator_1, + ... X_train=X_train, + ... y_train=y_train, + ... X_test=X_test, + ... y_test=y_test, + ... ) + >>> estimator_2 = Ridge(random_state=43) + >>> estimator_report_2 = EstimatorReport( + ... estimator_2, + ... X_train=X_train, + ... 
y_train=y_train, + ... X_test=X_test, + ... y_test=y_test, + ... ) + >>> comparison_report = ComparisonReport( + ... [estimator_report_1, estimator_report_2] + ... ) + >>> display = comparison_report.metrics.prediction_error() + >>> display.plot(kind="actual_vs_predicted") + """ + display_kwargs = {"subsample": subsample, "random_state": random_state} + return self._get_display( + X=X, + y=y, + data_source=data_source, + response_method="predict", + display_class=PredictionErrorDisplay, + display_kwargs=display_kwargs, + ) diff --git a/skore/src/skore/sklearn/_comparison/precision_recall_curve_display.py b/skore/src/skore/sklearn/_comparison/precision_recall_curve_display.py new file mode 100644 index 000000000..15b918d99 --- /dev/null +++ b/skore/src/skore/sklearn/_comparison/precision_recall_curve_display.py @@ -0,0 +1,309 @@ +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import average_precision_score, precision_recall_curve +from sklearn.preprocessing import LabelBinarizer + +from skore.sklearn._comparison.roc_curve_display import LINESTYLE +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _ClassifierCurveDisplayMixin, + _despine_matplotlib_axis, + sample_mpl_colormap, +) + + +class PrecisionRecallCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin): + """Precision Recall visualization. + + An instance of this class is should created by + `ComparisonReport.metrics.precision_recall()`. + You should not create an instance of this class directly. + + Parameters + ---------- + precision : dict of list of ndarray + Precision values. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the precision. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the precision. + + recall : dict of list of ndarray + Recall values. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the recall. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the recall. + + average_precision : dict of list of float + Average precision. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `float`, each `float` being the average + precision. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the average + precision. + + estimator_names : list[str] + Name of the estimators. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class. If None, the class will not + be shown in the legend. + + data_source : {"train", "test", "X_y"} + The data source used to compute the precision recall curve. + + Attributes + ---------- + ax_ : matplotlib Axes + Axes with precision recall curve, available after calling `plot`. + + figure_ : matplotlib Figure + Figure containing the curve, available after calling `plot`. + + lines_ : list of matplotlib lines + The lines of the precision recall curve, available after calling `plot`. 
+ """ + + def __init__( + self, + precision, + recall, + *, + average_precision, + estimator_names, + ml_task, + data_source, + pos_label=None, + ): + self.precision = precision + self.recall = recall + self.average_precision = average_precision + self.estimator_names = estimator_names + self.ml_task = ml_task + self.data_source = data_source + self.pos_label = pos_label + + def plot( + self, + ax=None, + *, + despine=True, + ): + """Plot visualization. + + Parameters + ---------- + ax : Matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Notes + ----- + The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) + in scikit-learn is computed without any interpolation. To be consistent + with this metric, the precision-recall curve is plotted without any + interpolation as well (step-wise style). + """ + self.figure_, self.ax_ = (ax.figure, ax) if ax else plt.subplots() + self.lines_ = [] + + if self.ml_task == "binary-classification": + for report_idx, report_name in enumerate(self.estimator_names): + precision = self.precision[self.pos_label][report_idx] + recall = self.recall[self.pos_label][report_idx] + average_precision = self.average_precision[self.pos_label][report_idx] + + self.lines_ += self.ax_.plot( + recall, + precision, + drawstyle="steps-post", + alpha=0.6, + label=( + f"{report_name} #{report_idx + 1} " + f"(AP = {average_precision:0.2f})" + ), + ) + + info_pos_label = ( + f"\n(Positive label: {self.pos_label})" + if self.pos_label is not None + else "" + ) + else: # multiclass-classification + info_pos_label = None # irrelevant for multiclass + colors = sample_mpl_colormap( + plt.cm.tab10, + 10 if len(self.estimator_names) < 10 else len(self.estimator_names), + ) + + for report_idx, report_name in enumerate(self.estimator_names): + report_color = colors[report_idx] + + for class_idx, class_ in enumerate(self.precision): + precision = self.precision[class_][report_idx] + recall = self.recall[class_][report_idx] + average_precision_class = self.average_precision[class_] + average_precision = average_precision_class[report_idx] + class_linestyle = LINESTYLE[(class_idx % len(LINESTYLE))][1] + + self.lines_ += self.ax_.plot( + recall, + precision, + color=report_color, + linestyle=class_linestyle, + alpha=0.6, + label=( + f"{report_name} #{report_idx + 1} - class {class_} " + f"(AP = {np.mean(average_precision_class):0.2f})" + ), + ) + + xlabel = "Recall" + ylabel = "Precision" + if info_pos_label: + xlabel += info_pos_label + ylabel += info_pos_label + + self.ax_.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if despine: + _despine_matplotlib_axis(self.ax_) + + self.ax_.legend( + loc="lower right", + title=f"{self.ml_task.title()} on $\\bf{{{self.data_source}}}$ set", + ) + + @classmethod + def _from_predictions( + cls, + y_true: list[list], + y_pred: list[list], + *, + estimators: list, + estimator_names: list[str], + ml_task, + data_source, + pos_label=None, + drop_intermediate=True, + ): + """Private factory to create a PrecisionRecallCurveDisplay from predictions. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True binary labels. 
+ + y_pred : array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive class, + confidence values, or non-thresholded measure of decisions (as returned by + “decision_function” on some classifiers). + + estimators : list of estimator instances + The estimators from which `y_pred` is obtained. + + estimator_names : list[str] + Name of the estimators used to plot the precision recall curve. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"} + The data source used to compute the precision recall curve. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the + precision and recall metrics. + + drop_intermediate : bool, default=False + Whether to drop some suboptimal thresholds which would not appear + on a plotted precision-recall curve. This is useful in order to + create lighter precision-recall curves. + + Returns + ------- + display : PrecisionRecallCurveDisplay + """ + estimator_classes = [estimator.classes_ for estimator in estimators] + precision, recall, average_precision = (defaultdict(list) for _ in range(3)) + pos_label_validated = cls._validate_from_predictions_params( + y_true, y_pred, ml_task=ml_task, pos_label=pos_label + ) + + if ml_task == "binary-classification": + for y_true_i, y_pred_i in zip(y_true, y_pred): + precision_i, recall_i, _ = precision_recall_curve( + y_true_i, + y_pred_i, + pos_label=pos_label_validated, + drop_intermediate=drop_intermediate, + ) + + precision[pos_label_validated].append(precision_i) + recall[pos_label_validated].append(recall_i) + average_precision[pos_label_validated].append( + average_precision_score( + y_true_i, + y_pred_i, + pos_label=pos_label_validated, + ) + ) + elif ml_task == "multiclass-classification": + for y_true_i, y_pred_i, estimator_classes_i in zip( + y_true, + y_pred, + estimator_classes, + ): + label_binarizer = LabelBinarizer().fit(estimator_classes_i) + y_true_onehot_i = label_binarizer.transform(y_true_i) + + for class_idx, class_ in enumerate(estimator_classes_i): + precision_class_i, recall_class_i, _ = precision_recall_curve( + y_true_onehot_i[:, class_idx], + y_pred_i[:, class_idx], + pos_label=None, + drop_intermediate=drop_intermediate, + ) + + precision[class_].append(precision_class_i) + recall[class_].append(recall_class_i) + average_precision[class_].append( + average_precision_score( + y_true_onehot_i[:, class_idx], + y_pred_i[:, class_idx], + ) + ) + else: + raise ValueError("Only binary or multiclass classification is allowed") + + return cls( + precision=dict(precision), + recall=dict(recall), + average_precision=dict(average_precision), + estimator_names=estimator_names, + ml_task=ml_task, + pos_label=pos_label_validated, + data_source=data_source, + ) diff --git a/skore/src/skore/sklearn/_comparison/prediction_error_display.py b/skore/src/skore/sklearn/_comparison/prediction_error_display.py new file mode 100644 index 000000000..3460dc531 --- /dev/null +++ b/skore/src/skore/sklearn/_comparison/prediction_error_display.py @@ -0,0 +1,248 @@ +import numbers + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.utils.validation import check_random_state + +from skore.externals._sklearn_compat import _safe_indexing +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _despine_matplotlib_axis, +) + + +class PredictionErrorDisplay(HelpDisplayMixin): + """Prediction error visualization for 
comparison report. + + This tool can display "residuals vs predicted" or "actual vs predicted" + using scatter plots to qualitatively assess the behavior of a regressor, + preferably on held-out data points. + + An instance of this class is should created by + `ComparisonReport.metrics.prediction_error()`. + You should not create an instance of this class directly. + + Parameters + ---------- + y_true : list of ndarray of shape (n_samples,) + True values. + + y_pred : list of ndarray of shape (n_samples,) + Prediction values. + + estimator_names : str + Name of the estimators. + + data_source : {"train", "test", "X_y"} + The data source used to display the prediction error. + + Attributes + ---------- + line_ : matplotlib line + Optimal line representing `y_true == y_pred`. Therefore, it is a + diagonal line for `kind="predictions"` and a horizontal line for + `kind="residuals"`, available after calling `plot`. + + scatters_ : list of matplotlib scatters + The scatters of the prediction error curve, available after calling `plot`. + + ax_ : matplotlib axes + The axes on which the prediction error curve is plotted, available after calling + `plot`. + + figure_ : matplotlib figure + The figure on which the prediction error curve is plotted, available after + calling `plot`. + """ + + def __init__(self, *, y_true, y_pred, estimator_names, data_source): + self.y_true = y_true + self.y_pred = y_pred + self.estimator_names = estimator_names + self.data_source = data_source + + def plot( + self, + ax=None, + *, + kind="residual_vs_predicted", + despine=True, + ): + """Plot visualization. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ + default="residual_vs_predicted" + The type of plot to draw: + + - "actual_vs_predicted" draws the observed values (y-axis) vs. + the predicted values (x-axis). + - "residual_vs_predicted" draws the residuals, i.e. difference + between observed and predicted values, (y-axis) vs. the predicted + values (x-axis). + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + """ + if kind == "actual_vs_predicted": + xlabel, ylabel = "Predicted values", "Actual values" + elif kind == "residual_vs_predicted": + xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)" + else: + raise ValueError( + "`kind` must be 'actual_vs_predicted' or 'residual_vs_predicted'. " + f"Got {kind!r} instead." 
+ ) + + self.figure_, self.ax_ = (ax.figure, ax) if ax else plt.subplots() + self.scatters_ = [ + self.ax_.scatter( + self.y_pred[report_idx], + ( + self.y_true[report_idx] + if kind == "actual_vs_predicted" + else (self.y_true[report_idx] - self.y_pred[report_idx]) + ), + label=f"{report_name} #{report_idx + 1}", + alpha=0.6, + s=10, + ) + for report_idx, report_name in enumerate(self.estimator_names) + ] + + x_range_perfect_pred = [np.inf, -np.inf] + y_range_perfect_pred = [np.inf, -np.inf] + for y_true, y_pred in zip(self.y_true, self.y_pred): + if kind == "actual_vs_predicted": + min_value = min(y_pred.min(), y_true.min()) + max_value = max(y_pred.max(), y_true.max()) + x_range_perfect_pred[0] = min(x_range_perfect_pred[0], min_value) + x_range_perfect_pred[1] = max(x_range_perfect_pred[1], max_value) + y_range_perfect_pred[0] = min(y_range_perfect_pred[0], min_value) + y_range_perfect_pred[1] = max(y_range_perfect_pred[1], max_value) + else: # kind == "residual_vs_predicted" + residuals = y_true - y_pred + x_range_perfect_pred[0] = min(x_range_perfect_pred[0], y_pred.min()) + x_range_perfect_pred[1] = max(x_range_perfect_pred[1], y_pred.max()) + y_range_perfect_pred[0] = min(y_range_perfect_pred[0], residuals.min()) + y_range_perfect_pred[1] = max(y_range_perfect_pred[1], residuals.max()) + + self.line_ = self.ax_.plot( + x_range_perfect_pred, + (y_range_perfect_pred if kind == "actual_vs_predicted" else [0, 0]), + color="black", + alpha=0.7, + linestyle="--", + label="Perfect predictions", + )[0] + + self.ax_.set( + aspect="equal", + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace(x_range_perfect_pred[0], x_range_perfect_pred[1], num=5), + yticks=np.linspace(y_range_perfect_pred[0], y_range_perfect_pred[1], num=5), + ) + self.ax_.set(xlabel=xlabel, ylabel=ylabel) + self.ax_.legend(title=f"Regression on $\\bf{{{self.data_source}}}$ set") + + if despine: + _despine_matplotlib_axis( + self.ax_, + x_range=self.ax_.get_xlim(), + y_range=self.ax_.get_ylim(), + ) + + @classmethod + def _from_predictions( + cls, + y_true, + y_pred, + *, + estimator_names, + ml_task, + data_source, + subsample=1_000, + random_state=None, + **kwargs, + ): + """Private factory to create a PredictionErrorDisplay from predictions. + + Parameters + ---------- + y_true : list of array-like of shape (n_samples,) + True target values. + + y_pred : list of array-like of shape (n_samples,) + Predicted target values. + + estimator_names : list[str] + Name of the estimators used to plot the prediction error curve. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"} + The data source used to compute the prediction error curve. + + subsample : float, int or None, default=1_000 + Sampling the samples to be shown on the scatter plot. If `float`, + it should be between 0 and 1 and represents the proportion of the + original dataset. If `int`, it represents the number of samples + display on the scatter plot. If `None`, no subsampling will be + applied. by default, 1000 samples or less will be displayed. + + random_state : int or RandomState, default=None + Controls the randomness when `subsample` is not `None`. + See :term:`Glossary ` for details. 
+ + Returns + ------- + display : PredictionErrorDisplay + """ + if ml_task != "regression": + raise ValueError("Only regression is allowed") + + random_state = check_random_state(random_state) + if isinstance(subsample, numbers.Integral): + if subsample <= 0: + raise ValueError( + f"When an integer, subsample should be positive; got {subsample}." + ) + elif isinstance(subsample, numbers.Real) and (subsample <= 0 or subsample >= 1): + raise ValueError( + f"When a floating-point, subsample should be between 0 and 1; " + f"got {subsample}." + ) + + y_true_display, y_pred_display = [], [] + for y_true_i, y_pred_i in zip(y_true, y_pred): + n_samples = len(y_true_i) + if subsample is None: + subsample_ = n_samples + elif isinstance(subsample, numbers.Integral): + subsample_ = subsample + else: # subsample is a float + subsample_ = int(n_samples * subsample) + + # normalize subsample based on the number of splits + subsample_ = int(subsample_ / len(y_true)) + if subsample_ < n_samples: + indices = random_state.choice(np.arange(n_samples), size=subsample_) + y_true_display.append(_safe_indexing(y_true_i, indices, axis=0)) + y_pred_display.append(_safe_indexing(y_pred_i, indices, axis=0)) + else: + y_true_display.append(y_true_i) + y_pred_display.append(y_pred_i) + + return cls( + y_true=y_true_display, + y_pred=y_pred_display, + estimator_names=estimator_names, + data_source=data_source, + ) diff --git a/skore/src/skore/sklearn/_comparison/roc_curve_display.py b/skore/src/skore/sklearn/_comparison/roc_curve_display.py new file mode 100644 index 000000000..09fa8d80d --- /dev/null +++ b/skore/src/skore/sklearn/_comparison/roc_curve_display.py @@ -0,0 +1,318 @@ +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import auc, roc_curve +from sklearn.preprocessing import LabelBinarizer + +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _ClassifierCurveDisplayMixin, + _despine_matplotlib_axis, + sample_mpl_colormap, +) + +LINESTYLE = [ + ("solid", "solid"), + ("dotted", "dotted"), + ("dashed", "dashed"), + ("dashdot", "dashdot"), + ("loosely dotted", (0, (1, 10))), + ("dotted", (0, (1, 5))), + ("densely dotted", (0, (1, 1))), + ("long dash with offset", (5, (10, 3))), + ("loosely dashed", (0, (5, 10))), + ("dashed", (0, (5, 5))), + ("densely dashed", (0, (5, 1))), + ("loosely dashdotted", (0, (3, 10, 1, 10))), + ("dashdotted", (0, (3, 5, 1, 5))), + ("densely dashdotted", (0, (3, 1, 1, 1))), + ("dashdotdotted", (0, (3, 5, 1, 5, 1, 5))), + ("loosely dashdotdotted", (0, (3, 10, 1, 10, 1, 10))), + ("densely dashdotdotted", (0, (3, 1, 1, 1, 1, 1))), +] + + +class RocCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin): + """ROC Curve visualization for comparison report. + + An instance of this class is should created by `ComparisonReport.metrics.roc()`. + You should not create an instance of this class directly. + + Parameters + ---------- + fpr : dict of list of ndarray + False positive rate. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the false + positive rate. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the false + positive rate. + + tpr : dict of list of ndarray + True positive rate. 
The structure is: + + - for binary classification: + - the key is the positive label + - the value is a list of `ndarray`, each `ndarray` being the true + positive rate. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the true + positive rate. + + roc_auc : dict of list of float + Area under the ROC curve. The structure is: + + - for binary classification: + - the key is the positive label + - the value is a list of `float`, each `float` being the area under + the ROC curve. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the area under + the ROC curve. + + estimator_names : str + Name of the estimators. + + ml_task : str + Type of ML task. + + pos_label : int, float, bool or str, default=None + The class considered as positive. Only meaningful for binary classification. + + data_source : {"train", "test", "X_y"} + The data source used to compute the ROC curve. + + Attributes + ---------- + ax_ : matplotlib axes + The axes on which the ROC curve is plotted, available after calling `plot`. + + figure_ : matplotlib figure + The figure on which the ROC curve is plotted, available after calling `plot`. + + lines_ : list of matplotlib lines + The lines of the ROC curve, available after calling `plot`. + """ + + def __init__( + self, + *, + fpr, + tpr, + roc_auc, + estimator_names, + ml_task, + data_source, + pos_label=None, + ): + self.fpr = fpr + self.tpr = tpr + self.roc_auc = roc_auc + self.estimator_names = estimator_names + self.ml_task = ml_task + self.data_source = data_source + self.pos_label = pos_label + + def plot( + self, + ax=None, + *, + plot_chance_level=True, + despine=True, + ): + """Plot visualization. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + plot_chance_level : bool, default=True + Whether to plot the chance level. + despine : bool, default=True + Whether to remove the top and right spines from the plot. 
+ """ + self.figure_, self.ax_ = (ax.figure, ax) if ax else plt.subplots() + self.lines_ = [] + + if self.ml_task == "binary-classification": + for report_idx, report_name in enumerate(self.estimator_names): + fpr = self.fpr[self.pos_label][report_idx] + tpr = self.tpr[self.pos_label][report_idx] + roc_auc = self.roc_auc[self.pos_label][report_idx] + + self.lines_ += self.ax_.plot( + fpr, + tpr, + alpha=0.6, + label=f"{report_name} #{report_idx + 1} (AUC = {roc_auc:0.2f})", + ) + + info_pos_label = ( + f"\n(Positive label: {self.pos_label})" + if self.pos_label is not None + else "" + ) + else: # multiclass-classification + info_pos_label = None # irrelevant for multiclass + colors = sample_mpl_colormap( + plt.cm.tab10, + 10 if len(self.estimator_names) < 10 else len(self.estimator_names), + ) + + for report_idx, report_name in enumerate(self.estimator_names): + report_color = colors[report_idx] + + for class_idx, class_ in enumerate(self.fpr): + fpr = self.fpr[class_][report_idx] + tpr = self.tpr[class_][report_idx] + roc_auc_mean = np.mean(self.roc_auc[class_]) + class_linestyle = LINESTYLE[(class_idx % len(LINESTYLE))][1] + + self.lines_ += self.ax_.plot( + fpr, + tpr, + alpha=0.6, + linestyle=class_linestyle, + color=report_color, + label=( + f"{report_name} #{report_idx + 1} - class {class_} " + f"(AUC = {roc_auc_mean:0.2f})" + ), + ) + + xlabel = "False Positive Rate" + ylabel = "True Positive Rate" + if info_pos_label: + xlabel += info_pos_label + ylabel += info_pos_label + + self.ax_.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if plot_chance_level: + self.ax_.plot( + (0, 1), + (0, 1), + label="Chance level (AUC = 0.5)", + color="k", + linestyle="--", + ) + + if despine: + _despine_matplotlib_axis(self.ax_) + + self.ax_.legend( + loc="lower right", + title=f"{self.ml_task.title()} on $\\bf{{{self.data_source}}}$ set", + ) + + @classmethod + def _from_predictions( + cls, + y_true, + y_pred, + *, + estimators, + estimator_names, + ml_task, + data_source, + pos_label=None, + drop_intermediate=True, + ): + """Private factory to create a RocCurveDisplay from predictions. + + Parameters + ---------- + y_true : list of array-like of shape (n_samples,) + True binary labels in binary classification. + + y_pred : list of array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive class, + confidence values, or non-thresholded measure of decisions (as returned by + “decision_function” on some classifiers). + + estimators : list of estimator instances + The estimators from which `y_pred` is obtained. + + estimator_names : list[str] + Name of the estimators used to plot the ROC curve. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"} + The data source used to compute the ROC curve. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the + precision and recall metrics. + + drop_intermediate : bool, default=True + Whether to drop intermediate points with identical value. 
+ + Returns + ------- + display : RocCurveDisplay + """ + estimator_classes = [estimator.classes_ for estimator in estimators] + fpr, tpr, roc_auc = (defaultdict(list) for _ in range(3)) + pos_label_validated = cls._validate_from_predictions_params( + y_true, y_pred, ml_task=ml_task, pos_label=pos_label + ) + + if ml_task == "binary-classification": + for y_true_i, y_pred_i in zip(y_true, y_pred): + fpr_i, tpr_i, _ = roc_curve( + y_true_i, + y_pred_i, + pos_label=pos_label, + drop_intermediate=drop_intermediate, + ) + + fpr[pos_label_validated].append(fpr_i) + tpr[pos_label_validated].append(tpr_i) + roc_auc[pos_label_validated].append(auc(fpr_i, tpr_i)) + elif ml_task == "multiclass-classification": + for y_true_i, y_pred_i, estimator_classes_i in zip( + y_true, + y_pred, + estimator_classes, + ): + label_binarizer = LabelBinarizer().fit(estimator_classes_i) + y_true_onehot_i = label_binarizer.transform(y_true_i) + + for class_idx, class_ in enumerate(estimator_classes_i): + fpr_class_i, tpr_class_i, _ = roc_curve( + y_true_onehot_i[:, class_idx], + y_pred_i[:, class_idx], + pos_label=None, + drop_intermediate=drop_intermediate, + ) + + fpr[class_].append(fpr_class_i) + tpr[class_].append(tpr_class_i) + roc_auc[class_].append(auc(fpr_class_i, tpr_class_i)) + else: + raise ValueError("Only binary or multiclass classification is allowed") + + return cls( + fpr=dict(fpr), + tpr=dict(tpr), + roc_auc=dict(roc_auc), + estimator_names=estimator_names, + ml_task=ml_task, + pos_label=pos_label_validated, + data_source=data_source, + ) diff --git a/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py index ff9cd5611..d3957132c 100644 --- a/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py +++ b/skore/src/skore/sklearn/_cross_validation/metrics_accessor.py @@ -12,6 +12,7 @@ RocCurveDisplay, ) from skore.utils._accessor import _check_supported_ml_task +from skore.utils._index import flatten_multi_index from skore.utils._parallel import Parallel, delayed from skore.utils._progress_bar import progress_decorator @@ -46,9 +47,10 @@ def report_metrics( data_source="test", scoring=None, scoring_names=None, - pos_label=None, scoring_kwargs=None, + pos_label=None, indicator_favorability=False, + flat_index=False, aggregate=None, ): """Report a set of metrics for our estimator. @@ -74,16 +76,20 @@ def report_metrics( Used to overwrite the default scoring names in the report. It should be of the same length as the `scoring` parameter. - pos_label : int, float, bool or str, default=None - The positive class. - scoring_kwargs : dict, default=None The keyword arguments to pass to the scoring functions. + pos_label : int, float, bool or str, default=None + The positive class. + indicator_favorability : bool, default=False Whether or not to add an indicator of the favorability of the metric as an extra column in the returned DataFrame. + flat_index : bool, default=False + Whether to flatten the `MultiIndex` columns. Flat index will always be lower + case, do not include spaces and remove the hash symbol to ease indexing. + aggregate : {"mean", "std"} or list of such str, default=None Function to aggregate the scores across the cross-validation splits. @@ -112,7 +118,7 @@ def report_metrics( Precision 0.94... 0.02... (↗︎) Recall 0.96... 0.02... 
(↗︎) """ - return self._compute_metric_scores( + results = self._compute_metric_scores( report_metric_name="report_metrics", data_source=data_source, aggregate=aggregate, @@ -122,6 +128,12 @@ def report_metrics( scoring_names=scoring_names, indicator_favorability=indicator_favorability, ) + if flat_index: + if isinstance(results.columns, pd.MultiIndex): + results.columns = flatten_multi_index(results.columns) + if isinstance(results.index, pd.MultiIndex): + results.index = flatten_multi_index(results.index) + return results @progress_decorator(description="Compute metric for each split") def _compute_metric_scores( diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py index e3f77dfa2..c427e65e5 100644 --- a/skore/src/skore/sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -16,6 +16,7 @@ RocCurveDisplay, ) from skore.utils._accessor import _check_supported_ml_task +from skore.utils._index import flatten_multi_index class _MetricsAccessor(_BaseAccessor, DirNamesMixin): @@ -48,9 +49,10 @@ def report_metrics( y=None, scoring=None, scoring_names=None, - pos_label=None, scoring_kwargs=None, + pos_label=None, indicator_favorability=False, + flat_index=False, ): """Report a set of metrics for our estimator. @@ -84,16 +86,20 @@ def report_metrics( Used to overwrite the default scoring names in the report. It should be of the same length as the `scoring` parameter. - pos_label : int, float, bool or str, default=None - The positive class. - scoring_kwargs : dict, default=None The keyword arguments to pass to the scoring functions. + pos_label : int, float, bool or str, default=None + The positive class. + indicator_favorability : bool, default=False Whether or not to add an indicator of the favorability of the metric as an extra column in the returned DataFrame. + flat_index : bool, default=False + Whether to flatten the multiindex columns. Flat index will always be lower + case, do not include spaces and remove the hash symbol to ease indexing. + Returns ------- pd.DataFrame @@ -339,7 +345,13 @@ def report_metrics( names=name_index, ) - return pd.concat(scores, axis=0) + results = pd.concat(scores, axis=0) + if flat_index: + if isinstance(results.columns, pd.MultiIndex): + results.columns = flatten_multi_index(results.columns) + if isinstance(results.index, pd.MultiIndex): + results.index = flatten_multi_index(results.index) + return results def _compute_metric_scores( self, diff --git a/skore/src/skore/sklearn/_plot/precision_recall_curve.py b/skore/src/skore/sklearn/_plot/precision_recall_curve.py index 11c36a199..f810a810b 100644 --- a/skore/src/skore/sklearn/_plot/precision_recall_curve.py +++ b/skore/src/skore/sklearn/_plot/precision_recall_curve.py @@ -147,11 +147,6 @@ def plot( despine : bool, default=True Whether to remove the top and right spines from the plot. - Returns - ------- - display : PrecisionRecallCurveDisplay - Object that stores computed values. - Notes ----- The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) @@ -402,7 +397,7 @@ def _from_predictions( The machine learning task. data_source : {"train", "test", "X_y"}, default=None - The data source used to compute the ROC curve. + The data source used to compute the precision recall curve. 
pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the @@ -415,7 +410,7 @@ def _from_predictions( Returns ------- - display : :class:`~sklearn.metrics.PrecisionRecallDisplay` + display : PrecisionRecallCurveDisplay """ pos_label_validated = cls._validate_from_predictions_params( y_true, y_pred, ml_task=ml_task, pos_label=pos_label diff --git a/skore/src/skore/sklearn/_plot/prediction_error.py b/skore/src/skore/sklearn/_plot/prediction_error.py index bc7ce670f..7f50e6217 100644 --- a/skore/src/skore/sklearn/_plot/prediction_error.py +++ b/skore/src/skore/sklearn/_plot/prediction_error.py @@ -129,11 +129,6 @@ def plot( despine : bool, default=True Whether to remove the top and right spines from the plot. - Returns - ------- - display : PredictionErrorDisplay - Object that stores computed values. - Examples -------- >>> from sklearn.datasets import load_diabetes @@ -313,7 +308,7 @@ def _from_predictions( The machine learning task. data_source : {"train", "test", "X_y"}, default=None - The data source used to compute the ROC curve. + The data source used to compute the prediction error curve. subsample : float, int or None, default=1_000 Sampling the samples to be shown on the scatter plot. If `float`, @@ -329,7 +324,6 @@ def _from_predictions( Returns ------- display : PredictionErrorDisplay - Object that stores the computed values. """ random_state = check_random_state(random_state) if isinstance(subsample, numbers.Integral): diff --git a/skore/src/skore/sklearn/_plot/roc_curve.py b/skore/src/skore/sklearn/_plot/roc_curve.py index e0c352aff..25067b4a4 100644 --- a/skore/src/skore/sklearn/_plot/roc_curve.py +++ b/skore/src/skore/sklearn/_plot/roc_curve.py @@ -157,11 +157,6 @@ def plot( despine : bool, default=True Whether to remove the top and right spines from the plot. - Returns - ------- - display : :class:`~sklearn.metrics.RocCurveDisplay` - Object that stores computed values. - Examples -------- >>> from sklearn.datasets import load_breast_cancer diff --git a/skore/src/skore/utils/_index.py b/skore/src/skore/utils/_index.py new file mode 100644 index 000000000..5b6a76f66 --- /dev/null +++ b/skore/src/skore/utils/_index.py @@ -0,0 +1,41 @@ +import pandas as pd + + +def flatten_multi_index(index: pd.MultiIndex) -> pd.Index: + """Flatten a pandas MultiIndex into a single-level Index. + + Flatten a pandas `MultiIndex` into a single-level Index by joining the levels + with underscores. Empty strings are skipped when joining. Spaces are replaced by + an underscore and "#" are skipped. + + Parameters + ---------- + index : pandas.MultiIndex + The `MultiIndex` to flatten. + + Returns + ------- + pandas.Index + A flattened `Index` with non-empty levels joined by underscores. + + Examples + -------- + >>> import pandas as pd + >>> mi = pd.MultiIndex.from_tuples( + ... [('a', ''), ('b', '2')], names=['letter', 'number'] + ... 
) + >>> flatten_multi_index(mi) + Index(['a', 'b_2'], dtype='object') + """ + if not isinstance(index, pd.MultiIndex): + raise ValueError("`index` must be a MultiIndex.") + + return pd.Index( + [ + "_".join(filter(bool, map(str, values))) + .replace(" ", "_") + .replace("#", "") + .lower() + for values in index + ] + ) diff --git a/skore/tests/unit/sklearn/test_comparison.py b/skore/tests/unit/sklearn/comparison/test_comparison.py similarity index 77% rename from skore/tests/unit/sklearn/test_comparison.py rename to skore/tests/unit/sklearn/comparison/test_comparison.py index 8734636d2..b6213b811 100644 --- a/skore/tests/unit/sklearn/test_comparison.py +++ b/skore/tests/unit/sklearn/comparison/test_comparison.py @@ -4,11 +4,17 @@ import joblib import pandas as pd import pytest +from numpy.testing import assert_allclose from sklearn.datasets import make_classification from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.metrics import mean_absolute_error from sklearn.model_selection import train_test_split from skore import ComparisonReport, EstimatorReport +from skore.sklearn._comparison.metrics_accessor import ( + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + RocCurveDisplay, +) @pytest.fixture @@ -536,6 +542,34 @@ def test_comparison_report_custom_metric_X_y(binary_classification_model): pd.testing.assert_frame_equal(result, expected) +def test_cross_validation_report_flat_index(binary_classification_model): + """Check that the index is flattened when `flat_index` is True. + + Since `pos_label` is None, then by default a MultiIndex would be returned. + Here, we force to have a single-index by passing `flat_index=True`. + """ + estimator, X_train, X_test, y_train, y_test = binary_classification_model + report_1 = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + report_2 = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + report = ComparisonReport({"report_1": report_1, "report_2": report_2}) + result = report.metrics.report_metrics(flat_index=True) + assert result.shape == (6, 2) + assert isinstance(result.index, pd.Index) + assert result.index.tolist() == [ + "precision_0", + "precision_1", + "recall_0", + "recall_1", + "roc_auc", + "brier_score", + ] + assert result.columns.tolist() == ["report_1", "report_2"] + + def test_estimator_report_report_metrics_indicator_favorability( binary_classification_model, ): @@ -557,3 +591,124 @@ def test_estimator_report_report_metrics_indicator_favorability( assert indicator["Recall"].tolist() == ["(↗︎)", "(↗︎)"] assert indicator["ROC AUC"].tolist() == ["(↗︎)"] assert indicator["Brier score"].tolist() == ["(↘︎)"] + + +@pytest.mark.parametrize("plot_data_source", ["test", "X_y"]) +@pytest.mark.parametrize( + "plot_ml_task, plot_name, plot_cls, plot_attributes", + [ + ( + "binary_classification", + "roc", + RocCurveDisplay, + { + "fpr": {1: [[0, 0, 0, 1], [0, 0, 0, 1]]}, + "tpr": {1: [[0, 0.1, 1, 1], [0, 0.1, 1, 1]]}, + "roc_auc": {1: [1, 1]}, + }, + ), + ( + "binary_classification", + "precision_recall", + PrecisionRecallCurveDisplay, + { + "precision": { + 1: [ + [0.4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0.4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + }, + "recall": { + 1: [ + [1, 1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0], + [1, 1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0], + ] + }, + "average_precision": {1: [0.99, 0.99]}, + }, + ), + ( + "regression", + "prediction_error", + 
PredictionErrorDisplay, + { + "y_true": [ + ( + [0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0] + + [1, 1, 1, 0] + ), + ( + [0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0] + + [1, 1, 1, 0] + ), + ], + "y_pred": [ + ( + [0.32, 1.25, 0.94, 0.77, -0.58, 0.89, 0.12, 0.51, 0.70, 0.52] + + [0.44, 0.14, 0.15, -0.13, -0.27, 0.24, 0.90, 0.22, 0.04] + + [-0.18, 0.20, 0.66, 0.99, 0.70, -0.03] + ), + ( + [0.32, 1.25, 0.94, 0.77, -0.58, 0.89, 0.12, 0.51, 0.70, 0.52] + + [0.44, 0.14, 0.15, -0.13, -0.27, 0.24, 0.90, 0.22, 0.04] + + [-0.18, 0.20, 0.66, 0.99, 0.70, -0.03] + ), + ], + }, + ), + ], +) +def test_comparison_report_plots( + plot_data_source, + plot_ml_task, + plot_name, + plot_cls, + plot_attributes, + binary_classification_model, + regression_model, +): + estimator, X_train, X_test, y_train, y_test = ( + binary_classification_model + if plot_ml_task == "binary_classification" + else regression_model + ) + estimator_report = EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + + comp = ComparisonReport([estimator_report, estimator_report]) + + if plot_data_source == "X_y": + arguments = {"data_source": plot_data_source, "X": X_test, "y": y_test} + else: + arguments = {"data_source": plot_data_source} + + # Ensure display object is available + display = getattr(comp.metrics, plot_name)(**arguments) + + # Ensure display object is of good type + assert isinstance(display, plot_cls) + + # Ensure data source is well set + assert display.data_source == plot_data_source + + # Ensure all attributes to test are well set + for attribute, value in plot_attributes.items(): + display_attribute_value = getattr(display, attribute) + + if isinstance(value, dict): + for k, v in value.items(): + assert isinstance(display_attribute_value, dict) + assert k in display_attribute_value + assert_allclose(display_attribute_value[k], v, atol=1e-2) + elif isinstance(value, list): + assert_allclose(display_attribute_value, value, atol=1e-2) + else: + raise NotImplementedError + + # Ensure plot is callable + display.plot() diff --git a/skore/tests/unit/sklearn/comparison/test_precision_recall_curve_display.py b/skore/tests/unit/sklearn/comparison/test_precision_recall_curve_display.py new file mode 100644 index 000000000..f9261b863 --- /dev/null +++ b/skore/tests/unit/sklearn/comparison/test_precision_recall_curve_display.py @@ -0,0 +1,188 @@ +from matplotlib.axes import Axes +from matplotlib.lines import Line2D +from numpy import array +from numpy.testing import assert_equal +from pytest import fixture +from sklearn.dummy import DummyClassifier +from skore.sklearn._comparison.precision_recall_curve_display import ( + PrecisionRecallCurveDisplay, +) + + +@fixture +def binary_classification_display(): + y_true = (array((0, 1)), array((0, 1))) + y_pred = (array((0.2, 0.8)), array((0.8, 0.2))) + estimators = ( + DummyClassifier().fit((0, 1), (0, 1)), + DummyClassifier().fit((0, 1), (0, 1)), + ) + + return PrecisionRecallCurveDisplay._from_predictions( + y_true=y_true, + y_pred=y_pred, + estimators=estimators, + estimator_names=["BC-E1", "BC-E2"], + ml_task="binary-classification", + data_source="test", + ) + + +@fixture +def multiclass_classification_display(): + y_true = (array((0, 1, 2)), array((0, 1, 2))) + y_pred = ( + array(((0.8, 0.2, 0.0), (0.0, 0.8, 0.2), (0.2, 0.0, 0.8))), + array(((0.0, 0.2, 0.8), (0.8, 0.0, 0.2), (0.2, 0.8, 0.0))), + ) + estimators = ( + DummyClassifier().fit((0, 1, 2), (0, 1, 2)), + DummyClassifier().fit((0, 1, 
2), (0, 1, 2)), + ) + + return PrecisionRecallCurveDisplay._from_predictions( + y_true=y_true, + y_pred=y_pred, + estimators=estimators, + estimator_names=["MC-E1", "MC-E2"], + ml_task="multiclass-classification", + data_source="test", + ) + + +class TestPrecisionRecallCurveDisplay: + def test_from_predictions_binary_classification( + self, binary_classification_display + ): + display = binary_classification_display + + assert_equal(display.precision, {1: [(0.5, 1, 1), (0.5, 0, 1)]}) + assert_equal(display.recall, {1: [(1, 1, 0), (1, 0, 0)]}) + assert_equal(display.average_precision, {1: [1, 0.5]}) + assert display.estimator_names == ["BC-E1", "BC-E2"] + assert display.ml_task == "binary-classification" + assert display.pos_label == 1 + assert display.data_source == "test" + + def test_from_predictions_multiclass_classification( + self, multiclass_classification_display + ): + display = multiclass_classification_display + + assert_equal( + display.precision, + { + 0: [(1 / 3, 1, 1), (1 / 3, 0, 0, 1)], + 1: [(1 / 3, 1, 1), (1 / 3, 0, 0, 1)], + 2: [(1 / 3, 1, 1), (1 / 3, 0, 0, 1)], + }, + ) + assert_equal( + display.recall, + { + 0: [(1, 1, 0), (1, 0, 0, 0)], + 1: [(1, 1, 0), (1, 0, 0, 0)], + 2: [(1, 1, 0), (1, 0, 0, 0)], + }, + ) + assert_equal( + display.average_precision, + {0: (1, 1 / 3), 1: (1, 1 / 3), 2: (1, 1 / 3)}, + ) + assert display.estimator_names == ["MC-E1", "MC-E2"] + assert display.ml_task == "multiclass-classification" + assert display.pos_label is None + assert display.data_source == "test" + + def test_plot_binary_classification(self, tmp_path, binary_classification_display): + display = binary_classification_display + display.plot() + + # Test `lines_` attribute + assert isinstance(display.lines_, list) + assert len(display.lines_) == 2 + assert all(isinstance(line, Line2D) for line in display.lines_) + assert_equal(display.lines_[0].get_xdata(), (1.0, 1.0, 0.0)) + assert_equal(display.lines_[1].get_xdata(), (1.0, 0.0, 0.0)) + assert_equal(display.lines_[0].get_ydata(), (0.5, 1.0, 1.0)) + assert_equal(display.lines_[1].get_ydata(), (0.5, 0.0, 1.0)) + assert display.lines_[0].get_label() == "BC-E1 #1 (AP = 1.00)" + assert display.lines_[1].get_label() == "BC-E2 #2 (AP = 0.50)" + assert display.lines_[0].get_color() != display.lines_[1].get_color() + + # Test `ax_` attribute, its lines, legend and title + assert isinstance(display.ax_, Axes) + assert display.ax_.lines[:2] == display.lines_ + assert ( + display.ax_.get_legend().get_title().get_text() + == "Binary-Classification on $\\bf{test}$ set" + ) + assert display.ax_.get_xlabel() == "Recall\n(Positive label: 1)" + assert display.ax_.get_ylabel() == "Precision\n(Positive label: 1)" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + # Test `figure_` attribute + assert display.ax_.figure == display.figure_ + + def test_plot_multiclass_classification( + self, tmp_path, multiclass_classification_display + ): + display = multiclass_classification_display + display.plot() + + # Test `lines_` attribute + assert isinstance(display.lines_, list) + assert len(display.lines_) == 6 + assert all(isinstance(line, Line2D) for line in display.lines_) + assert_equal(display.lines_[0].get_xdata(), (1.0, 1.0, 0.0)) + assert_equal(display.lines_[1].get_xdata(), (1.0, 1.0, 0.0)) + assert_equal(display.lines_[2].get_xdata(), (1.0, 1.0, 0.0)) + assert_equal(display.lines_[3].get_xdata(), (1.0, 0.0, 0.0, 0.0)) + 
assert_equal(display.lines_[4].get_xdata(), (1.0, 0.0, 0.0, 0.0)) + assert_equal(display.lines_[5].get_xdata(), (1.0, 0.0, 0.0, 0.0)) + assert_equal(display.lines_[0].get_ydata(), (1 / 3, 1.0, 1.0)) + assert_equal(display.lines_[1].get_ydata(), (1 / 3, 1.0, 1.0)) + assert_equal(display.lines_[2].get_ydata(), (1 / 3, 1.0, 1.0)) + assert_equal(display.lines_[3].get_ydata(), (1 / 3, 0.0, 0.0, 1.0)) + assert_equal(display.lines_[4].get_ydata(), (1 / 3, 0.0, 0.0, 1.0)) + assert_equal(display.lines_[5].get_ydata(), (1 / 3, 0.0, 0.0, 1.0)) + assert display.lines_[0].get_label() == "MC-E1 #1 - class 0 (AP = 0.67)" + assert display.lines_[1].get_label() == "MC-E1 #1 - class 1 (AP = 0.67)" + assert display.lines_[2].get_label() == "MC-E1 #1 - class 2 (AP = 0.67)" + assert display.lines_[3].get_label() == "MC-E2 #2 - class 0 (AP = 0.67)" + assert display.lines_[4].get_label() == "MC-E2 #2 - class 1 (AP = 0.67)" + assert display.lines_[5].get_label() == "MC-E2 #2 - class 2 (AP = 0.67)" + assert display.lines_[0].get_color() != display.lines_[3].get_color() + assert ( + display.lines_[0].get_color() + == display.lines_[1].get_color() + == display.lines_[2].get_color() + ) + assert ( + display.lines_[3].get_color() + == display.lines_[4].get_color() + == display.lines_[5].get_color() + ) + assert display.lines_[0].get_linestyle() != display.lines_[1].get_linestyle() + assert display.lines_[1].get_linestyle() != display.lines_[2].get_linestyle() + assert display.lines_[0].get_linestyle() == display.lines_[3].get_linestyle() + assert display.lines_[1].get_linestyle() == display.lines_[4].get_linestyle() + assert display.lines_[2].get_linestyle() == display.lines_[5].get_linestyle() + + # Test `ax_` attribute, its lines, legend and title + assert isinstance(display.ax_, Axes) + assert display.ax_.lines[:6] == display.lines_ + assert ( + display.ax_.get_legend().get_title().get_text() + == "Multiclass-Classification on $\\bf{test}$ set" + ) + assert display.ax_.get_xlabel() == "Recall" + assert display.ax_.get_ylabel() == "Precision" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + # Test `figure_` attribute + assert display.ax_.figure == display.figure_ diff --git a/skore/tests/unit/sklearn/comparison/test_prediction_error_display.py b/skore/tests/unit/sklearn/comparison/test_prediction_error_display.py new file mode 100644 index 000000000..8fdcebe7a --- /dev/null +++ b/skore/tests/unit/sklearn/comparison/test_prediction_error_display.py @@ -0,0 +1,122 @@ +from matplotlib.axes import Axes +from matplotlib.collections import PathCollection +from matplotlib.lines import Line2D +from numpy import array, array_equal +from numpy.testing import assert_equal +from pytest import fixture +from skore.sklearn._comparison.prediction_error_display import PredictionErrorDisplay + + +@fixture +def regression_display(): + y_true = (array((-1, 0, 1)), array((-1, 0, 1))) + y_pred = (array((0, 0.2, 0.8)), array((0.8, 0.2, 0))) + + return PredictionErrorDisplay._from_predictions( + y_true=y_true, + y_pred=y_pred, + estimator_names=["R-E1", "R-E2"], + ml_task="regression", + data_source="test", + ) + + +class TestPredictionErrorDisplay: + def test_from_predictions(self, regression_display): + display = regression_display + + assert_equal(display.y_true, [(-1, 0, 1), (-1, 0, 1)]) + assert_equal(display.y_pred, [(0, 0.2, 0.8), (0.8, 0.2, 0)]) + assert display.estimator_names == ["R-E1", "R-E2"] + 
assert display.data_source == "test" + + def test_plot_residual_vs_predicted(self, tmp_path, regression_display): + display = regression_display + display.plot(kind="residual_vs_predicted") + + # Test `line_` attribute + assert isinstance(display.line_, Line2D) + assert_equal(display.line_.get_xdata(), (0.0, 0.8)) + assert_equal(display.line_.get_ydata(), (0.0, 0.0)) + assert display.line_.get_label() == "Perfect predictions" + assert display.line_.get_color() == "black" + + # Test `scatters_` attribute + assert isinstance(display.scatters_, list) + assert len(display.scatters_) == 2 + assert all(isinstance(scatter, PathCollection) for scatter in display.scatters_) + assert_equal( + display.scatters_[0].get_offsets().data, + ((0.0, (-1.0 - 0)), (0.2, (0.0 - 0.2)), (0.8, (1 - 0.8))), + ) + assert_equal( + display.scatters_[1].get_offsets().data, + ((0.8, (-1 - 0.8)), (0.2, (0 - 0.2)), (0.0, (1.0 - 0.0))), + ) + assert display.scatters_[0].get_label() == "R-E1 #1" + assert display.scatters_[1].get_label() == "R-E2 #2" + assert not array_equal( + display.scatters_[0].get_facecolor(), display.scatters_[1].get_facecolor() + ) + + # Test `ax_` attribute, its scatters, legend and title + assert isinstance(display.ax_, Axes) + assert display.ax_.collections[:2] == display.scatters_ + assert ( + display.ax_.get_legend().get_title().get_text() + == "Regression on $\\bf{test}$ set" + ) + assert display.ax_.get_xlabel() == "Predicted values" + assert display.ax_.get_ylabel() == "Residuals (actual - predicted)" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == (0.00, 0.80) + assert display.ax_.get_ylim() == (-1.80, 1.00) + + # Test `figure_` attribute + assert display.ax_.figure == display.figure_ + + def test_plot_actual_vs_predicted(self, tmp_path, regression_display): + display = regression_display + display.plot(kind="actual_vs_predicted") + + # Test `line_` attribute + assert isinstance(display.line_, Line2D) + assert_equal(display.line_.get_xdata(), (-1.0, 1.0)) + assert_equal(display.line_.get_ydata(), (-1.0, 1.0)) + assert display.line_.get_label() == "Perfect predictions" + assert display.line_.get_color() == "black" + + # Test `scatters_` attribute + assert isinstance(display.scatters_, list) + assert len(display.scatters_) == 2 + assert all(isinstance(scatter, PathCollection) for scatter in display.scatters_) + assert_equal( + display.scatters_[0].get_offsets().data, + ((0.0, -1.0), (0.2, 0.0), (0.8, 1.0)), + ) + assert_equal( + display.scatters_[1].get_offsets().data, + ((0.8, -1.0), (0.2, 0.0), (0.0, 1.0)), + ) + assert display.scatters_[0].get_label() == "R-E1 #1" + assert display.scatters_[1].get_label() == "R-E2 #2" + assert not array_equal( + display.scatters_[0].get_facecolor(), display.scatters_[1].get_facecolor() + ) + + # Test `ax_` attribute, its scatters, legend and title + assert isinstance(display.ax_, Axes) + assert display.ax_.collections[:2] == display.scatters_ + assert ( + display.ax_.get_legend().get_title().get_text() + == "Regression on $\\bf{test}$ set" + ) + assert display.ax_.get_xlabel() == "Predicted values" + assert display.ax_.get_ylabel() == "Actual values" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-1.00, 1.00) + + # Test `figure_` attribute + assert display.ax_.figure == display.figure_ diff --git a/skore/tests/unit/sklearn/comparison/test_roc_curve_display.py 
b/skore/tests/unit/sklearn/comparison/test_roc_curve_display.py
new file mode 100644
index 000000000..4ab0460d4
--- /dev/null
+++ b/skore/tests/unit/sklearn/comparison/test_roc_curve_display.py
@@ -0,0 +1,183 @@
+from matplotlib.axes import Axes
+from matplotlib.lines import Line2D
+from numpy import array
+from numpy.testing import assert_equal
+from pytest import fixture
+from sklearn.dummy import DummyClassifier
+from skore.sklearn._comparison.roc_curve_display import RocCurveDisplay
+
+
+@fixture
+def binary_classification_display():
+    y_true = (array((0, 1)), array((0, 1)))
+    y_pred = (array((0.2, 0.8)), array((0.8, 0.2)))
+    estimators = (
+        DummyClassifier().fit((0, 1), (0, 1)),
+        DummyClassifier().fit((0, 1), (0, 1)),
+    )
+
+    return RocCurveDisplay._from_predictions(
+        y_true=y_true,
+        y_pred=y_pred,
+        estimators=estimators,
+        estimator_names=["BC-E1", "BC-E2"],
+        ml_task="binary-classification",
+        data_source="test",
+    )
+
+
+@fixture
+def multiclass_classification_display():
+    y_true = (array((0, 1, 2)), array((0, 1, 2)))
+    y_pred = (
+        array(((0.8, 0.2, 0.0), (0.0, 0.8, 0.2), (0.2, 0.0, 0.8))),
+        array(((0.0, 0.2, 0.8), (0.8, 0.0, 0.2), (0.2, 0.8, 0.0))),
+    )
+    estimators = (
+        DummyClassifier().fit((0, 1, 2), (0, 1, 2)),
+        DummyClassifier().fit((0, 1, 2), (0, 1, 2)),
+    )
+
+    return RocCurveDisplay._from_predictions(
+        y_true=y_true,
+        y_pred=y_pred,
+        estimators=estimators,
+        estimator_names=["MC-E1", "MC-E2"],
+        ml_task="multiclass-classification",
+        data_source="test",
+    )
+
+
+class TestRocCurveDisplay:
+    def test_from_predictions_binary_classification(
+        self, binary_classification_display
+    ):
+        display = binary_classification_display
+
+        assert_equal(display.fpr, {1: [(0, 0, 1), (0, 1, 1)]})
+        assert_equal(display.tpr, {1: [(0, 1, 1), (0, 0, 1)]})
+        assert_equal(display.roc_auc, {1: [1, 0]})
+        assert display.estimator_names == ["BC-E1", "BC-E2"]
+        assert display.ml_task == "binary-classification"
+        assert display.pos_label == 1
+        assert display.data_source == "test"
+
+    def test_from_predictions_multiclass_classification(
+        self, multiclass_classification_display
+    ):
+        display = multiclass_classification_display
+
+        assert_equal(
+            display.fpr,
+            {
+                0: [(0.0, 0.0, 1.0), (0.0, 0.5, 1.0, 1.0)],
+                1: [(0.0, 0.0, 1.0), (0.0, 0.5, 1.0, 1.0)],
+                2: [(0.0, 0.0, 1.0), (0.0, 0.5, 1.0, 1.0)],
+            },
+        )
+        assert_equal(
+            display.tpr,
+            {
+                0: [(0.0, 1.0, 1.0), (0.0, 0.0, 0.0, 1.0)],
+                1: [(0.0, 1.0, 1.0), (0.0, 0.0, 0.0, 1.0)],
+                2: [(0.0, 1.0, 1.0), (0.0, 0.0, 0.0, 1.0)],
+            },
+        )
+        assert_equal(display.roc_auc, {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [1.0, 0.0]})
+        assert display.estimator_names == ["MC-E1", "MC-E2"]
+        assert display.ml_task == "multiclass-classification"
+        assert display.pos_label is None
+        assert display.data_source == "test"
+
+    def test_plot_binary_classification(self, tmp_path, binary_classification_display):
+        display = binary_classification_display
+        display.plot()
+
+        # Test `lines_` attribute
+        assert isinstance(display.lines_, list)
+        assert len(display.lines_) == 2
+        assert all(isinstance(line, Line2D) for line in display.lines_)
+        assert_equal(display.lines_[0].get_xdata(), (0.0, 0.0, 1.0))
+        assert_equal(display.lines_[1].get_xdata(), (0.0, 1.0, 1.0))
+        assert_equal(display.lines_[0].get_ydata(), (0.0, 1.0, 1.0))
+        assert_equal(display.lines_[1].get_ydata(), (0.0, 0.0, 1.0))
+        assert display.lines_[0].get_label() == "BC-E1 #1 (AUC = 1.00)"
+        assert display.lines_[1].get_label() == "BC-E2 #2 (AUC = 0.00)"
+        assert display.lines_[0].get_color() != 
display.lines_[1].get_color() + + # Test `ax_` attribute, its lines, legend and title + assert isinstance(display.ax_, Axes) + assert display.ax_.lines[:2] == display.lines_ + assert ( + display.ax_.get_legend().get_title().get_text() + == "Binary-Classification on $\\bf{test}$ set" + ) + assert display.ax_.get_xlabel() == "False Positive Rate\n(Positive label: 1)" + assert display.ax_.get_ylabel() == "True Positive Rate\n(Positive label: 1)" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + # Test `figure_` attribute + assert display.ax_.figure == display.figure_ + + def test_plot_multiclass_classification( + self, tmp_path, multiclass_classification_display + ): + display = multiclass_classification_display + display.plot() + + # Test `lines_` attribute + assert isinstance(display.lines_, list) + assert len(display.lines_) == 6 + assert all(isinstance(line, Line2D) for line in display.lines_) + assert_equal(display.lines_[0].get_xdata(), (0.0, 0.0, 1.0)) + assert_equal(display.lines_[1].get_xdata(), (0.0, 0.0, 1.0)) + assert_equal(display.lines_[2].get_xdata(), (0.0, 0.0, 1.0)) + assert_equal(display.lines_[3].get_xdata(), (0.0, 0.5, 1.0, 1.0)) + assert_equal(display.lines_[4].get_xdata(), (0.0, 0.5, 1.0, 1.0)) + assert_equal(display.lines_[5].get_xdata(), (0.0, 0.5, 1.0, 1.0)) + assert_equal(display.lines_[0].get_ydata(), (0.0, 1.0, 1.0)) + assert_equal(display.lines_[1].get_ydata(), (0.0, 1.0, 1.0)) + assert_equal(display.lines_[2].get_ydata(), (0.0, 1.0, 1.0)) + assert_equal(display.lines_[3].get_ydata(), (0.0, 0.0, 0.0, 1.0)) + assert_equal(display.lines_[4].get_ydata(), (0.0, 0.0, 0.0, 1.0)) + assert_equal(display.lines_[5].get_ydata(), (0.0, 0.0, 0.0, 1.0)) + assert display.lines_[0].get_label() == "MC-E1 #1 - class 0 (AUC = 0.50)" + assert display.lines_[1].get_label() == "MC-E1 #1 - class 1 (AUC = 0.50)" + assert display.lines_[2].get_label() == "MC-E1 #1 - class 2 (AUC = 0.50)" + assert display.lines_[3].get_label() == "MC-E2 #2 - class 0 (AUC = 0.50)" + assert display.lines_[4].get_label() == "MC-E2 #2 - class 1 (AUC = 0.50)" + assert display.lines_[5].get_label() == "MC-E2 #2 - class 2 (AUC = 0.50)" + assert display.lines_[0].get_color() != display.lines_[3].get_color() + assert ( + display.lines_[0].get_color() + == display.lines_[1].get_color() + == display.lines_[2].get_color() + ) + assert ( + display.lines_[3].get_color() + == display.lines_[4].get_color() + == display.lines_[5].get_color() + ) + assert display.lines_[0].get_linestyle() != display.lines_[1].get_linestyle() + assert display.lines_[1].get_linestyle() != display.lines_[2].get_linestyle() + assert display.lines_[0].get_linestyle() == display.lines_[3].get_linestyle() + assert display.lines_[1].get_linestyle() == display.lines_[4].get_linestyle() + assert display.lines_[2].get_linestyle() == display.lines_[5].get_linestyle() + + # Test `ax_` attribute, its lines, legend and title + assert isinstance(display.ax_, Axes) + assert display.ax_.lines[:6] == display.lines_ + assert ( + display.ax_.get_legend().get_title().get_text() + == "Multiclass-Classification on $\\bf{test}$ set" + ) + assert display.ax_.get_xlabel() == "False Positive Rate" + assert display.ax_.get_ylabel() == "True Positive Rate" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + # Test 
`figure_` attribute + assert display.ax_.figure == display.figure_ diff --git a/skore/tests/unit/sklearn/test_cross_validation.py b/skore/tests/unit/sklearn/test_cross_validation.py index 3c246a881..c26cae511 100644 --- a/skore/tests/unit/sklearn/test_cross_validation.py +++ b/skore/tests/unit/sklearn/test_cross_validation.py @@ -219,6 +219,31 @@ def test_cross_validation_report_pickle(tmp_path, binary_classification_data): joblib.dump(report, tmp_path / "report.joblib") +def test_cross_validation_report_flat_index(binary_classification_data): + """Check that the index is flattened when `flat_index` is True. + + Since `pos_label` is None, then by default a MultiIndex would be returned. + Here, we force to have a single-index by passing `flat_index=True`. + """ + estimator, X, y = binary_classification_data + report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=2) + result = report.metrics.report_metrics(flat_index=True) + assert result.shape == (6, 2) + assert isinstance(result.index, pd.Index) + assert result.index.tolist() == [ + "precision_0", + "precision_1", + "recall_0", + "recall_1", + "roc_auc", + "brier_score", + ] + assert result.columns.tolist() == [ + "randomforestclassifier_split_0", + "randomforestclassifier_split_1", + ] + + ######################################################################################## # Check the plot methods ######################################################################################## diff --git a/skore/tests/unit/sklearn/test_estimator.py b/skore/tests/unit/sklearn/test_estimator.py index edee9bac3..0c4585b6d 100644 --- a/skore/tests/unit/sklearn/test_estimator.py +++ b/skore/tests/unit/sklearn/test_estimator.py @@ -350,6 +350,28 @@ def test_estimator_report_pickle(binary_classification_data): joblib.dump(report, stream) +def test_estimator_report_flat_index(binary_classification_data): + """Check that the index is flattened when `flat_index` is True. + + Since `pos_label` is None, then by default a MultiIndex would be returned. + Here, we force to have a single-index by passing `flat_index=True`. 
+ """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(flat_index=True) + assert result.shape == (6, 1) + assert isinstance(result.index, pd.Index) + assert result.index.tolist() == [ + "precision_0", + "precision_1", + "recall_0", + "recall_1", + "roc_auc", + "brier_score", + ] + assert result.columns.tolist() == ["RandomForestClassifier"] + + ######################################################################################## # Check the plot methods ######################################################################################## diff --git a/skore/tests/unit/utils/test_index.py b/skore/tests/unit/utils/test_index.py new file mode 100644 index 000000000..1a305b23b --- /dev/null +++ b/skore/tests/unit/utils/test_index.py @@ -0,0 +1,63 @@ +import pandas as pd +import pytest +from skore.utils._index import flatten_multi_index + + +@pytest.mark.parametrize( + "input_tuples, names, expected_values", + [ + pytest.param( + [("A", 1), ("B", 2)], ["letter", "number"], ["a_1", "b_2"], id="basic" + ), + pytest.param( + [("A", 1, "X"), ("B", 2, "Y")], + ["letter", "number", "symbol"], + ["a_1_x", "b_2_y"], + id="multiple_levels", + ), + pytest.param( + [("A", None), (None, 2)], + ["letter", "number"], + ["a_nan", "nan_2.0"], + id="none_values", + ), + pytest.param( + [("A@B", "1#2"), ("C&D", "3$4")], + ["letter", "number"], + ["a@b_12", "c&d_3$4"], + id="special_chars", + ), + pytest.param([], ["letter", "number"], [], id="empty"), + pytest.param( + [("Hello World", "A B"), ("Space Test", "X Y")], + ["text", "more"], + ["hello_world_a_b", "space_test_x_y"], + id="spaces", + ), + pytest.param( + [("A#B#C", "1#2#3"), ("X#Y", "5#6")], + ["text", "numbers"], + ["abc_123", "xy_56"], + id="hash_symbols", + ), + pytest.param( + [("UPPER", "CASE"), ("MiXeD", "cAsE")], + ["text", "type"], + ["upper_case", "mixed_case"], + id="case_sensitivity", + ), + ], +) +def test_flatten_multi_index(input_tuples, names, expected_values): + """Test flatten_multi_index with various input cases.""" + mi = pd.MultiIndex.from_tuples(input_tuples, names=names) + result = flatten_multi_index(mi) + expected = pd.Index(expected_values) + pd.testing.assert_index_equal(result, expected) + + +def test_flatten_multi_index_invalid_input(): + """Test that non-MultiIndex input raises ValueError.""" + simple_index = pd.Index(["a", "b"]) + with pytest.raises(ValueError, match="`index` must be a MultiIndex."): + flatten_multi_index(simple_index)