|
"""
=========================
Multilabel classification
=========================

This example simulates a multi-label document classification problem. The
dataset is generated randomly based on the following process:

    - pick the number of labels: n ~ Poisson(n_labels)
    - n times, choose a class c: c ~ Multinomial(theta)
    - pick the document length: k ~ Poisson(length)
    - k times, choose a word: w ~ Multinomial(theta_c)

In the above process, rejection sampling is used to make sure that
n is never zero or more than 2, and that the document length
is never zero. Likewise, we reject classes which have already been chosen.
The documents that are assigned to both classes are plotted surrounded by
two colored circles.

The classification is performed by projecting to the first two principal
components for visualisation purposes, followed by using the
:class:`sklearn.multiclass.OneVsRestClassifier` metaclassifier using two SVCs
with linear kernels to learn a discriminative model for each class.
"""
# ``print`` is a function in Python 3; the bare-statement form used here
# originally (``print __doc__``) is Python 2 only and is a SyntaxError now.
print(__doc__)
| 26 | + |
| 27 | +import numpy as np |
| 28 | +import matplotlib.pylab as pl |
| 29 | + |
| 30 | +from sklearn.datasets import make_multilabel_classification |
| 31 | +from sklearn.multiclass import OneVsRestClassifier |
| 32 | +from sklearn.svm import SVC |
| 33 | +from sklearn.decomposition import PCA |
| 34 | + |
| 35 | + |
def plot_hyperplane(clf, min_x, max_x, linestyle, label):
    """Draw the separating line of a fitted linear classifier.

    ``clf`` must expose ``coef_`` and ``intercept_`` (e.g. a linear SVC).
    Its boundary w0*x + w1*y + b = 0 is redrawn as the explicit line
    y = -(w0 / w1) * x - b / w1 over the interval [min_x, max_x].
    """
    w0, w1 = clf.coef_[0][:2]
    slope = -w0 / w1
    xs = np.linspace(min_x, max_x)
    ys = slope * xs - clf.intercept_[0] / w1
    pl.plot(xs, ys, linestyle, label=label)
| 43 | + |
| 44 | + |
# Generate a random 2-class multilabel problem and project the samples
# onto their first two principal components for 2-D visualisation.
X, Y = make_multilabel_classification(n_classes=2, n_labels=1, random_state=42)
X = PCA(n_components=2).fit_transform(X)
min_x, max_x = X[:, 0].min(), X[:, 0].max()

# One linear SVC per label via the one-vs-rest strategy.
classif = OneVsRestClassifier(SVC(kernel='linear'))
classif.fit(X, Y)
| 52 | + |
pl.figure()
pl.title('Multilabel classification example')
pl.xlabel('First principal component')
pl.ylabel('Second principal component')

# Indices of the samples carrying each label.
# NOTE(review): this membership test assumes Y is a sequence of label
# lists/sets per sample (old scikit-learn default); with an indicator
# matrix the equivalent would be np.where(Y[:, k]) -- confirm version.
zero_class = np.where([0 in y for y in Y])
one_class = np.where([1 in y for y in Y])

# All points in gray, then a colored ring around each labeled sample;
# points with both labels end up circled twice (two radii, two colors).
pl.scatter(X[:, 0], X[:, 1], s=40, c='gray')
pl.scatter(X[zero_class, 0], X[zero_class, 1], s=160, edgecolors='b',
           facecolors='none', linewidths=2, label='Class 1')
pl.scatter(X[one_class, 0], X[one_class, 1], s=80, edgecolors='orange',
           facecolors='none', linewidths=2, label='Class 2')
pl.axis('tight')

# Overlay the per-label decision boundaries of the two fitted SVCs.
for est_idx, (style, name) in enumerate([('k--', 'Boundary\nfor class 1'),
                                         ('k-.', 'Boundary\nfor class 2')]):
    plot_hyperplane(classif.estimators_[est_idx], min_x, max_x, style, name)

pl.xticks(())
pl.yticks(())
pl.legend()

pl.show()
0 commit comments