美文网首页数据可视化呆鸟的Python数据分析数据可视化分析
数据可视化:pandas透视图、seaborn热力图

数据可视化:pandas透视图、seaborn热力图

作者: 273123e8cd8a | 来源:发表于2020-02-22 16:14 被阅读0次

    1. 创建需要展示的数据

    import itertools
    
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    # === define paras ==================
    para_names = ['layer_n', 'activition', 'seed']
    
    layer_n = [1, 2, 3, 4, 5, 6]
    activition = ['tanh', 'sigmod', 'relu']
    seed = [11, 17, 19]
    
    # 创建 dataframe
    df =  pd.DataFrame([], columns=para_names)
    for values in itertools.product(layer_n, activition, seed):
        newline = pd.DataFrame(list(values), index=para_names)
        df = df.append(newline.T, ignore_index=True)
    
    # 伪造一些训练结果,方便展示
    # activ_2_num = pd.factorize(df['activition'])[0].astype('int')  # 激活函数是字符类型,将其映射成整数形
    activ_dict = {'tanh': 2, 'sigmod': 4, 'relu': 6}  # 也可以直接定义字典,然后replace
    df['results'] = df['layer_n'] + df['activition'].replace(activ_dict) + df['seed'] * 0.1 + np.random.random((54,))
    df['results'] = df['results'].astype('float')  # 转换成浮点类型
    print(df.head())
    

    输出:

      layer_n activition seed   results
    0       1       tanh   11  4.261361
    1       1       tanh   17  4.822595
    2       1       tanh   19  4.929088
    3       1     sigmod   11  6.698047
    4       1     sigmod   17  7.020531
    

    2. 绘制带误差的折线图展示训练结果

    # 绘制带误差的折线图,横轴为网络层数,纵轴为训练结果,
    # 激活函数采用不同颜色的线型,误差来自于没有指定的列:不同的随机种子seed
    plt.figure(figsize=(8, 6))
    sns.lineplot(x='layer_n', y='results', hue='activition',  style='activition', 
                 markers=True, data=df)
    plt.grid(linestyle=':')
    plt.show()
    

    3. 使用pandas透视图、seaborn热力图来展示

    # 创建透视图,
    # 对于没有指定的列(seed),按最大值进行统计
    dt = pd.pivot_table(df, index=['layer_n'], columns=['activition'], values=['results'], aggfunc=[max])
    print(dt)
    print(dt.columns)  
    
    # 找到最大值、最大值所对应的索引
    max_value, max_idx = dt.stack().max(), dt.stack().idxmax()
    print(f' - the max value is {max_value};\n - the index is {max_idx}...')
    
    # 透视图变成了多重索引(MultiIndex),重新调整一下
    new_col = dt.columns.levels[2]
    dt.columns = new_col
    # dt.index = list(dt.index)
    print(dt)
    
    dt.sort_index(axis=0, ascending=False, inplace=True)  # 必要时将索引重新排序
    dt.sort_index(axis=1, ascending=False, inplace=True)  # 必要时将索引重新排序
    
    # 绘制热力图,横轴为网络层数,纵轴为激活函数,
    # 栅格的颜色代表训练结果,颜色越深结果越好
    plt.figure(figsize=(8, 6))
    g = sns.heatmap(dt, vmin=0.0, annot=True, fmt='.2g', cmap='Blues', cbar=True)
    plt.show()
    

    ref:

    相关文章

      网友评论

        本文标题:数据可视化:pandas透视图、seaborn热力图

        本文链接:https://www.haomeiwen.com/subject/anflqhtx.html