使用python matplotlib绘制混淆矩阵
今天使用了python matplotlib包,绘制混淆矩阵。基本代码参考官网教程,在此基础上增加了格网显示。
代码说明:
cm - 混淆矩阵的数值, 是一个二维numpy数组
classes - 各个类别的标签(label)
title - 图片标题
cmap - 颜色图
def plot_Matrix(cm, classes, title=None, cmap=plt.cm.Blues):
plt.rc('font',family='Times New Roman',size='8') # 设置字体样式、大小
# 按行进行归一化
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
str_cm = cm.astype(np.str).tolist()
for row in str_cm:
print('\t'.join(row))
# 占比1%以下的单元格,设为0,防止在最后的颜色中体现出来
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
if int(cm[i, j]*100 + 0.5) == 0:
cm[i, j]=0
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
# ax.figure.colorbar(im, ax=ax) # 侧边的颜色条带
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=classes, yticklabels=classes,
title=title,
ylabel='Actual',
xlabel='Predicted')
# 通过绘制格网,模拟每个单元格的边框
ax.set_xticks(np.arange(cm.shape[1]+1)-.5, minor=True)
ax.set_yticks(np.arange(cm.shape[0]+1)-.5, minor=True)
ax.grid(which="minor", color="gray", linestyle='-', linewidth=0.2)
ax.tick_params(which="minor", bottom=False, left=False)
# 将x轴上的lables旋转45度
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# 标注百分比信息
fmt = 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
if int(cm[i, j]*100 + 0.5) > 0:
ax.text(j, i, format(int(cm[i, j]*100 + 0.5) , fmt) + '%',
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
plt.savefig('cm.jpg', dpi=300)
plt.show()
最后,画出来的效果如下:
由于四舍五入的原因,导致每行加起来的结果有时候不为1,可能需要手动调整一下~