
Unavoidable "y_true and y_pred contain different number of classes" error inside a CV loop #11777

Open
@gwerbin

Description

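During cross-validation on a multi-class problem, a test fold and its training fold can end up with different sets of classes: the test data may contain classes that don't appear in the training data, or, with only a few samples per class, a test fold may be missing a class that the model was trained on.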

Steps/Code to Reproduce

import numpy as np
from sklearn.metrics import make_scorer, log_loss
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.naive_bayes import BernoulliNB

rs = np.random.RandomState(1389057)

y = [
    'cow',
    'hedgehog',
    'fox',
    'fox',
    'hedgehog',
    'fox',
    'hedgehog',
    'cow',
    'cow',
    'fox'
]

x = rs.normal([0, 0], [1, 1], size=(len(y), 2))

model = BernoulliNB()

cv = StratifiedKFold(4, shuffle=True, random_state=rs)

param_dist = {
    # np.logspace takes base-10 exponents, so use log10 to span 0.1 to 1
    'alpha': np.logspace(np.log10(0.1), np.log10(1), 20)
}

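# log_loss is a loss (lower is better), so the scorer must negate it
search = RandomizedSearchCV(model, param_dist, n_iter=5,
                            scoring=make_scorer(log_loss, needs_proba=True,
                                                greater_is_better=False),
                            cv=cv)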

search.fit(x, y)

Expected Results

Either:

  1. Predicted classes from predict_proba are aligned with classes in the full training data, not just the in-fold subset.
  2. Classes not in the training data are ignored in the test data.

Actual Results

Predicted classes from predict_proba are aligned with classes in the in-fold subset only, but classes not in the training data are still used in the test data, causing the error.

I understand that this is normatively "correct" behavior, but it makes it hard/impossible to use in cross-validation with the existing APIs.
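The mismatch can be shown without any cross-validation machinery. The following is a minimal sketch, assuming a hand-written probability matrix (`proba`) and a hypothetical test fold (`y_fold`) that contains only two of the three classes; neither comes from the actual failing run. It only relies on the documented `labels` argument of `log_loss`.

```python
import numpy as np
from sklearn.metrics import log_loss

# Column order follows a fitted classifier's classes_: cow, fox, hedgehog.
all_labels = ["cow", "fox", "hedgehog"]

# Made-up predict_proba output for two test samples (illustrative values).
proba = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
])

# Hypothetical test fold in which 'hedgehog' never occurs.
y_fold = ["cow", "fox"]

# Without labels=, log_loss infers the classes from y_fold (2 classes)
# and rejects the 3-column probability matrix.
try:
    log_loss(y_fold, proba)
    raised = False
except ValueError as exc:
    raised = True
    print(exc)

# With labels=, the columns are matched against the full class list.
loss = log_loss(y_fold, proba, labels=all_labels)
print(loss)
```

Inside a CV loop the scorer calls `log_loss` the first way, which is why the error surfaces whenever a fold's true labels cover fewer classes than the model's `classes_`.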

From my perspective, the best solution would be to have RandomizedSearchCV pass a labels=self.classes_ argument to its scorer. I'm not sure how well that generalizes.
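Until something like that exists, one possible workaround is to bind the full label set into the scorer by hand, since `make_scorer` forwards extra keyword arguments to the metric. This is a sketch under assumptions, not a confirmed fix: `all_labels` is a name introduced here, the integer `random_state` values are arbitrary, and the `inspect` check is only there because newer scikit-learn releases replaced `needs_proba` with `response_method`.

```python
import inspect

import numpy as np
from sklearn.metrics import log_loss, make_scorer
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.naive_bayes import BernoulliNB

rs = np.random.RandomState(1389057)
y = ['cow', 'hedgehog', 'fox', 'fox', 'hedgehog',
     'fox', 'hedgehog', 'cow', 'cow', 'fox']
x = rs.normal(size=(len(y), 2))

# Sorted to match the order of classes_ on a classifier fit on all classes.
all_labels = sorted(set(y))

# Pick whichever keyword this scikit-learn version uses to request
# predict_proba output (assumption: one of the two is available).
if "response_method" in inspect.signature(make_scorer).parameters:
    proba_kwargs = {"response_method": "predict_proba"}
else:
    proba_kwargs = {"needs_proba": True}

# labels=all_labels is passed through to log_loss on every scoring call.
scorer = make_scorer(log_loss, greater_is_better=False,
                     labels=all_labels, **proba_kwargs)

search = RandomizedSearchCV(
    BernoulliNB(),
    {'alpha': np.logspace(-1, 0, 20)},
    n_iter=5,
    scoring=scorer,
    cv=StratifiedKFold(4, shuffle=True, random_state=0),
    random_state=0,
)
search.fit(x, y)
print(search.best_params_, search.best_score_)
```

This only helps when every training fold still contains all classes, so that `predict_proba` returns one column per entry in `all_labels`. If a class is missing from a training fold, the column count no longer matches and the scorer would still fail.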

Versions

Linux-3.10.0-514.26.2.el7.x86_64-x86_64-with-redhat-7.3-Maipo
Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 17:14:51) [GCC 7.2.0]
NumPy 1.15.0
SciPy 1.1.0
Scikit-Learn 0.19.1
