美文网首页大数据 爬虫Python AI SqlTensorFlow技术帖
数据可视化-混淆矩阵(confusion matrix)

数据可视化-混淆矩阵(confusion matrix)

作者: 洗洗睡吧i | 来源:发表于2019-09-25 12:27 被阅读0次

    1. 混淆矩阵(confusion matrix)介绍

    在基于深度学习的分类识别领域中,经常采用统计学中的混淆矩阵(confusion matrix)来评价分类器的性能。

    它是一种特定的二维矩阵:

    • 列代表预测的类别;行代表实际的类别。
    • 对角线上的值表示预测正确的数量/比例;非对角线元素是预测错误的部分。

    混淆矩阵的对角线值越高越好,表明许多正确的预测。

    特别是在各分类数据的数量不平衡的情况下,混淆矩阵可以直观的显示分类模型对应各个类别的准确率。

    ref: https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    2. 混淆矩阵示列

    • 数据集: MNIST
    • tensorflow,keras,
    • 神经网络:CNN

    依赖:kerasmatplotlibnumpyseaborntensorflowsklearn

    import keras
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    
    from sklearn.metrics import confusion_matrix
    
    # === dataset ===
    with np.load('mnist.npz') as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    
    x_train = x_train.reshape(60000, 28, 28, 1)
    x_test = x_test.reshape(10000, 28, 28, 1)
    print(x_train.shape)
    print(x_test.shape)
    
    # === model: CNN ===
    model = keras.models.Sequential()
    model.add(keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(keras.layers.MaxPooling2D((2, 2)))
    model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(keras.layers.MaxPooling2D((2, 2)))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(64, activation='relu'))
    model.add(keras.layers.Dense(10, activation='softmax'))
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()
    
    # === train ===
    model.fit(x=x_train, y=y_train,
              batch_size=512,
              epochs=10,
              validation_data=(x_test, y_test))
    
    # === pred ===
    y_pred = model.predict_classes(x_test)
    print(y_pred)
    
    # === 混淆矩阵:真实值与预测值的对比 ===
    # https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    con_mat = confusion_matrix(y_test, y_pred)
    
    con_mat_norm = con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis]     # 归一化
    con_mat_norm = np.around(con_mat_norm, decimals=2)
    
    # === plot ===
    plt.figure(figsize=(8, 8))
    sns.heatmap(con_mat_norm, annot=True, cmap='Blues')
    
    plt.ylim(0, 10)
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()
    
    Figure_1.png

    相关文章

      网友评论

        本文标题:数据可视化-混淆矩阵(confusion matrix)

        本文链接:https://www.haomeiwen.com/subject/lpnnectx.html