—————————————————–
Visualizing where misclassifications occurred with a confusion matrix
—————————————————–
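In:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

# expected / predicted: true and predicted labels from the earlier
# classification step (defined outside this snippet)
# plot_confusion_matrix: plotting helper as in the scikit-learn
# "Confusion matrix" documentation example (must be defined beforehand)

# Labels of the two classes, in sorted order (rows/columns of cm)
class_names = [3, 8]

# Rows = true labels, columns = predicted labels
cm = confusion_matrix(expected, predicted)
print(cm)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cm, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix (each row divided by its total)
plt.figure()
plot_confusion_matrix(cm, normalize=True, classes=class_names,
                      title='Normalized confusion matrix')

plt.show()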
—————————————————–
[[60 15]
[ 2 66]]
Confusion matrix, without normalization
[[60 15]
[ 2 66]]
Normalized confusion matrix
[[ 0.8 0.2 ]
[ 0.03 0.97]]
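To make the normalized matrix concrete, here is a minimal sketch that recomputes it from the raw counts printed above using NumPy. This step is not shown in the original post; it assumes the usual scikit-learn layout (rows are true labels, columns are predicted labels). Dividing each row by its total gives values consistent with the printed 0.8/0.2 and 0.03/0.97, and the off-diagonal cells are exactly the misclassifications.

```python
import numpy as np

# Raw counts printed above: rows = true class (3, 8), columns = predicted class
cm = np.array([[60, 15],
               [2, 66]])
class_names = [3, 8]

# Row-normalize: divide each row by the number of samples whose true label is that class
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]
print(np.round(cm_norm, 2))

# Off-diagonal cells (true label != predicted label) are the misclassifications
for i, true_label in enumerate(class_names):
    for j, pred_label in enumerate(class_names):
        if i != j and cm[i, j] > 0:
            print(f"true {true_label} -> predicted {pred_label}: {cm[i, j]}")
```

Read this way, most errors are class 3 samples predicted as 8 (15 of 75, i.e. 20%), while only 2 of the 68 class 8 samples (about 3%) were predicted as 3.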