中野智文のブログ

データ・マエショリストのメモ

scikit learn で cross validation で confusion matrix を取得する。

背景

confusion matrix を取得する場合は、一部の例だけでなく cross validation で全ての事例に対して取得したい。

対応

sklearn.model_selection.cross_val_predict を使う。

今回は iris のデータを使う。当たり前だが、confusion matrix は普通分類問題に使うよね?

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target

にも関わらず、例だから、logistic regression などを使ってみる。

from sklearn.linear_model import LogisticRegression
logistic = LogisticRegression()

ここで、cross validation 登場。 cv パラメータは、KFold(n_splits=10, shuffle=True) に設定している*1。推定結果と教師データから、confusion matrix を生成する。

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_pred = cross_val_predict(logistic,X,y,cv=KFold(n_splits=10, shuffle=True))
conf_mat = confusion_matrix(y,y_pred)

コードを書かなくていいし、そのコードのバグも心配しなくていい*2。検証するコードがバグっていて、何を色々努力してもダメ、というのは最悪なパターンだと思うけど、 機械学習ではよくある話のような気がする。

そして、よくある、 confusion matrix を plot するコード。

# scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
import matplotlib.pyplot as plt
%matplotlib inline
import itertools
import numpy as np
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

そして実行+結果

plot_confusion_matrix(conf_mat, list(iris.target_names))
Confusion matrix, without normalization
[[50  0  0]
 [ 0 45  5]
 [ 0  2 48]]

f:id:nakano-tomofumi:20171109141252p:plain

まとめ

confusion matrix を取得したいときは、sklearn.model_selection.cross_val_predict を使う。

*1:詳しくは scikit learn の Kfold, StratifiedKFold, ShuffleSplit の違い - 中野智文のブログ

*2:と書いたけどこれは嘘。心配(注意)はした方がいい。ただ有名なコードなら多くの人が心配してくれるし、直してもくれるから安心