scikit learn で cross validation で confusion matrix を取得する。
背景
confusion matrix を取得する場合は、一部の例だけでなく cross validation で全ての事例に対して取得したい。
対応
sklearn.model_selection.cross_val_predict を使う。
例
今回は iris のデータを使う。当たり前だが、confusion matrix は普通分類問題に使うよね?
from sklearn import datasets iris = datasets.load_iris() X = iris.data y = iris.target
にも関わらず、例だから、logistic regression などを使ってみる。
from sklearn.linear_model import LogisticRegression logistic = LogisticRegression()
ここで、cross validation 登場。 cv
パラメータは、KFold(n_splits=10, shuffle=True)
に設定している*1。推定結果と教師データから、confusion matrix を生成する。
from sklearn.model_selection import (cross_val_predict, KFold) from sklearn.metrics import confusion_matrix y_pred = cross_val_predict(logistic,X,y,cv=KFold(n_splits=10, shuffle=True)) conf_mat = confusion_matrix(y,y_pred)
コードを書かなくていいし、そのコードのバグも心配しなくていい*2。検証するコードがバグっていて、何を色々努力してもダメ、というのは最悪なパターンだと思うけど、 機械学習ではよくある話のような気がする。
そして、よくある、 confusion matrix を plot するコード。
# scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html import matplotlib.pyplot as plt %matplotlib inline import itertools import numpy as np def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. """ if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] print("Normalized confusion matrix") else: print('Confusion matrix, without normalization') print(cm) plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, classes, rotation=45) plt.yticks(tick_marks, classes) fmt = '.2f' if normalize else 'd' thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label')
そして実行+結果
plot_confusion_matrix(conf_mat, list(iris.target_names))
Confusion matrix, without normalization
[[50 0 0]
[ 0 45 5]
[ 0 2 48]]
まとめ
confusion matrix を取得したいときは、sklearn.model_selection.cross_val_predict を使う。 このcolab
広告
*1:詳しくは scikit learn の Kfold, StratifiedKFold, ShuffleSplit の違い - 中野智文のブログ
*2:と書いたけどこれは嘘。心配(注意)はした方がいい。ただ有名なコードなら多くの人が心配してくれるし、直してもくれるから安心