irpas技术客

动手画混淆矩阵(Confusion Matrix)(含代码)_我是一个对称矩阵_画混淆矩阵

网络 1126

1、混淆矩阵:Confusion Matrix2、怎么画?3、怎么用?

网上关于混淆矩阵的代码参差不齐,没找到可用的线程的代码,所以自己尝试写了下

1、混淆矩阵:Confusion Matrix

首先它长这样: 怎么看? Confusion Matrix最广泛的应用应该是分类,比如图中是7分类的真实标签和预测标签的效果。 首先图中表明了纵轴是truth label,横轴是predicted label,那么对于第一行第一个0.60的含义是:本来是angry标签的图,我的模型正确分类成angry的比例是60%,也即是angry这一类模型分类正确的精度只有60%。同时模型将angry分类成了happy的图占比0.04%,其他的以此类推。

注意:因为本身是angry,模型预测成7种类的数量占比。所以每一行的和为100%。

同时对于fear标签,模型分类成fear的占比41%,分类成sad的占比为20%,我们可以认为模型不能很好区分fear和sad两种类别。

2、怎么画?

先给出代码:

import numpy as np import matplotlib.pyplot as plt class DrawConfusionMatrix: def __init__(self, labels_name, normalize=True): """ normalize:是否设元素为百分比形式 """ self.normalize = normalize self.labels_name = labels_name self.num_classes = len(labels_name) self.matrix = np.zeros((self.num_classes, self.num_classes), dtype="float32") def update(self, predicts, labels): """ :param predicts: 一维预测向量,eg:array([0,5,1,6,3,...],dtype=int64) :param labels: 一维标签向量:eg:array([0,5,0,6,2,...],dtype=int64) :return: """ for predict, label in zip(predicts, labels): self.matrix[predict, label] += 1 def getMatrix(self,normalize=True): """ 根据传入的normalize判断要进行percent的转换, 如果normalize为True,则矩阵元素转换为百分比形式, 如果normalize为False,则矩阵元素就为数量 Returns:返回一个以百分比或者数量为元素的矩阵 """ if normalize: per_sum = self.matrix.sum(axis=1) # 计算每行的和,用于百分比计算 for i in range(self.num_classes): self.matrix[i] =(self.matrix[i] / per_sum[i]) # 百分比转换 self.matrix=np.around(self.matrix, 2) # 保留2位小数点 self.matrix[np.isnan(self.matrix)] = 0 # 可能存在NaN,将其设为0 return self.matrix def drawMatrix(self): self.matrix = self.getMatrix(self.normalize) plt.imshow(self.matrix, cmap=plt.cm.Blues) # 仅画出颜色格子,没有值 plt.title("Normalized confusion matrix") # title plt.xlabel("Predict label") plt.ylabel("Truth label") plt.yticks(range(self.num_classes), self.labels_name) # y轴标签 plt.xticks(range(self.num_classes), self.labels_name, rotation=45) # x轴标签 for x in range(self.num_classes): for y in range(self.num_classes): value = float(format('%.2f' % self.matrix[y, x])) # 数值处理 plt.text(x, y, value, verticalalignment='center', horizontalalignment='center') # 写值 plt.tight_layout() # 自动调整子图参数,使之填充整个图像区域 plt.colorbar() # 色条 plt.savefig('./ConfusionMatrix.png', bbox_inches='tight') # bbox_inches='tight'可确保标签信息显示全 plt.show()

混淆矩阵是将所有数据的label和predict整理而画的,但实际中往往是分成多个iter来推测batch_size个数据,所以需要update()函数来讲每一次的label和predict值保存进去,模型推理完成后,再调用draw()函数画出混淆矩阵并保存为图片

3、怎么用?

给出一个简单的实例:

labels_name=['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'] drawconfusionmatrix = DrawConfusionMatrix(labels_name=labels_name) # 实例化 for index, (labels, imgs) in enumerate(test_loader): labels_pd = model(imgs) predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1) # array([0,5,1,6,3,...],dtype=int64) labels_np = labels.numpy() # array([0,5,0,6,2,...],dtype=int64) drawconfusionmatrix.update(predict_np, labels_np) # 将新批次的predict和label更新(保存) drawconfusionmatrix.drawMatrix() # 根据所有predict和label,画出混淆矩阵 confusion_mat=drawconfusionmatrix.getMatrix() # 你也可以使用该函数获取混淆矩阵(ndarray) print(confusion) cpu().detach():从device上获取数据.numpy():将tensor类型转换为numpy类型

在我的模型上的结果:


1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,会注明原创字样,如未注明都非原创,如有侵权请联系删除!;3.作者投稿可能会经我们编辑修改或补充;4.本站不提供任何储存功能只提供收集或者投稿人的网盘链接。

标签: #画混淆矩阵 #1混淆矩阵Confusion #首先图中表明了纵轴是truth