美文网首页Deep Learning
scikit-learn中classification_repo

scikit-learn中classification_repo

作者: 王二牛牛 | 来源:发表于2018-01-23 09:54 被阅读1513次
    from sklearn.metrics import classification_report, accuracy_score
    classification_report(y_test, y_pre, target_names=target_names)
    

    scikit-learn中的classification_report是强大的函数,可以计算查全率,查准率,F1参数,keras中没有相关的函数,并且keraslabel为one-hot,输出的为[0.3.0.2,0.5]这样的softmax数据,如何转化为[4,5,5]这样的标签数据用于适配classification_report函数。

    1、one-hot转化为整数label

    代码如下:

    import numpy as np
    def onehot_to_category(onehot):
          b = np.array([[0], [1], [2],[3],[4],[5]])
          return np.dot(onehot,b).flatten()
    

    原理很简单,矩阵的乘法
    输入时n*6的矩阵,n个样本,一共6类,6代表onehot编码如[0,0,0,0,0,1]

    b为列向量,shape为6*1

    np.dot(a,b)代表a矩阵与b矩阵的乘法,输出矩阵为n*1,即转化为整数的lable形式

    2、softmax输出转化为整数label

    代码如下
    def softmax_to_category(a):
    max2 = []
    for item in a:
    i=np.argmax(item)
    max2.append(i)
    return max2
    比较low的方法,诸葛数据取最大值,取最大值的索引,就是整数形式的label

    相关文章

      网友评论

        本文标题:scikit-learn中classification_repo

        本文链接:https://www.haomeiwen.com/subject/qggjaxtx.html