您当前的位置:首页 > IT编程 > 深度学习
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:python绘制多分类混淆矩阵

51自学网 2020-11-13 07:50:58
  深度学习
这篇教程python绘制多分类混淆矩阵写得很实用,希望能帮到您。

对于分类问题,一般使用混淆矩阵来分析各类别预测的结果,可视化混淆矩阵来分析预测结果从而得到调参思路。

 

绘制多分类混淆矩阵

%matplotlib inline
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, savename, title='Confusion Matrix'):

    plt.figure(figsize=(12, 8), dpi=100)
    np.set_printoptions(precision=2)

    # 在混淆矩阵中每格的概率值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=15, va='center', ha='center')
    
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes, rotation=90)
    plt.yticks(xlocations, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')
    
    # offset the tick
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)
    
    # show confusion matrix
    plt.savefig(savename, format='png')
    plt.show()

 

获取实际标签、预测结果和混淆矩阵:

# classes表示不同类别的名称,比如这有6个类别
classes = ['A', 'B', 'C', 'D', 'E', 'F']

random_numbers = np.random.randint(6, size=50)  # 6个类别,随机生成50个样本
y_true = random_numbers.copy()  # 样本实际标签
random_numbers[:10] = np.random.randint(6, size=10)  # 将前10个样本的值进行随机更改
y_pred = random_numbers  # 样本预测标签

# 获取混淆矩阵
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, 'confusion_matrix.png', title='confusion matrix')

 

得到如下图所示:

 

我们看类别A,预测结果和实际标签都为A的有12个样本,把A样本预测为其他类别的有3个样本(同一行的其他样本),而把其他类别预测为A样本的有1个样本(同一列的其他样本)。其他类别也同样这样分析。

 

通常我们会在绘图前对混淆矩阵按行做一个标准化处理,即得到的是概率值,每行所有的概率之和为1,所以对角线就代表每个类别的查全率(召回率)。

# Normalize by row
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print(cm_normalized)

 

输出为如下形式的二维矩阵(和上面的混淆矩阵并不对应):

[[0.7  0.   0.   0.   0.1  0.2 ]
 [0.12 0.88 0.   0.   0.   0.  ]
 [0.   0.   0.71 0.   0.29 0.  ]
 [0.   0.   0.   0.83 0.17 0.  ]
 [0.   0.   0.   0.   1.   0.  ]
 [0.   0.   0.   0.   0.15 0.85]]

 


keras CNN卷积核可视化,热度图 python源代码
classification_report 保留多位小数
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。