美文网首页
python画热力图

python画热力图

作者: ltochange | 来源:发表于2021-08-01 22:32 被阅读0次

    python中可使用seaborn.heatmap画热力图,官方文档在这

    在分类任务中,也可用于画混淆矩阵:

    import numpy as np
    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt
    
    
    def confusion_matrix(y_true, y_pred, labels=None):
        n = len(labels)
        labels_dict = {label: i for i, label in enumerate(labels)}
        res = np.zeros([n, n], dtype=np.int32)
        for gold, predict in zip(y_true, y_pred):
            res[labels_dict[gold]][labels_dict[predict]] += 1
    
        df = pd.DataFrame(res, index=labels, columns=labels)
        sns.heatmap(df, annot=True, fmt='d')
        plt.savefig("./confusion_matrix.jpg")
        plt.show()
    
    y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]  # 真实
    y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]  # 预测
    labels = ["ant", "bird", "cat"]
    
    confusion_matrix(y_true, y_pred, labels)
    
    
    在这里插入图片描述

    一些参数的含义:

    def heatmap(
        data, *,
        vmin=None, vmax=None, cmap=None, center=None, robust=False,
        annot=None, fmt=".2g", annot_kws=None,
        linewidths=0, linecolor="white",
        cbar=True, cbar_kws=None, cbar_ax=None,
        square=False, xticklabels="auto", yticklabels="auto",
        mask=None, ax=None,
        **kwargs
    )
    
    • 根据data传入的值画出热力图,一般是二维矩阵
    • vmin设置最小值, vmax设置最大值
    • cmap换用不同的颜色
    • center设置中心值
    • annot 是否在方格上写上对应的数字
    • fmt 写入热力图的数据类型,默认为科学计数,d表示整数,.1f表示保留一位小数
    • linewidths 设置方格之间的间隔
    • xticklabels,yticklabels填到横纵坐标的值。可以是bool,填或者不填。可以是int,以什么间隔填,可以是list

    例子:

    import numpy as np
    np.random.seed(0)
    import seaborn as sns
    sns.set_theme()
    uniform_data = np.random.rand(10, 12)
    ax = sns.heatmap(uniform_data)
    
    在这里插入图片描述

    将最后一行改为,设置最大值和最小值:

    ax = sns.heatmap(uniform_data, vmin=0, vmax=1)
    
    在这里插入图片描述

    设置中心值:

    normal_data = np.random.randn(10, 12)
    ax = sns.heatmap(normal_data, center=0)
    
    在这里插入图片描述

    从文件中获取数据,并画图给出有意义的横纵坐标:

    flights = sns.load_dataset("flights")
    flights = flights.pivot("month", "year", "passengers")
    ax = sns.heatmap(flights)
    
    在这里插入图片描述

    将passengers对应的人数标出:

    ax = sns.heatmap(flights, annot=True, fmt="d")
    
    在这里插入图片描述

    设置方格之间的间隔:

    ax = sns.heatmap(flights, linewidths=.5)
    
    在这里插入图片描述

    设置使用不同的颜色:

    ax = sns.heatmap(flights, cmap="YlGnBu")
    
    在这里插入图片描述

    以某个具体的数据为中心:

    ax = sns.heatmap(flights, center=flights.loc["Jan", 1955])
    
    在这里插入图片描述

    自动填充坐标值:

    data = np.random.randn(50, 20)
    ax = sns.heatmap(data, xticklabels=2, yticklabels=False)
    
    在这里插入图片描述

    不画右边的热度条:

    ax = sns.heatmap(flights, cbar=False)
    
    在这里插入图片描述

    相关文章

      网友评论

          本文标题:python画热力图

          本文链接:https://www.haomeiwen.com/subject/fborvltx.html