美文网首页
Scikit-plot 画图神器

Scikit-plot 画图神器

作者: 顾北向南 | 来源:发表于2019-06-27 20:59 被阅读0次

    本文转载至机器学习初学者(ID:ai-start-com)
    作者:黄海广博士

    1. 安装说明

    2. 使用说明

    • 比如画出分类评级指标的ROC曲线的完整代码:
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.naive_bayes import GaussianNB
    X, y = load_digits(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
    nb = GaussianNB()
    nb.fit(X_train, y_train)
    predicted_probas = nb.predict_proba(X_test)
    # The magic happens here
    import matplotlib.pyplot as plt
    import scikitplot as skplt
    skplt.metrics.plot_roc(y_test, predicted_probas)
    plt.show()
    
    • P-R曲线就是精确率precision vs 召回率recall 曲线,以recall作为横坐标轴,precision作为纵坐标轴。首先解释一下精确率和召回率。
    import matplotlib.pyplot as plt
    from sklearn.naive_bayes import GaussianNB
    from sklearn.datasets import load_digits as load_data
    import scikitplot as skplt
    # Load dataset
    X, y = load_data(return_X_y=True)
    # Create classifier instance then fit
    nb = GaussianNB()
    nb.fit(X,y)
    # Get predicted probabilities
    y_probas = nb.predict_proba(X)
    skplt.metrics.plot_precision_recall_curve(y, y_probas, cmap='nipy_spectral')
    plt.show()
    
    • 混淆矩阵是分类的重要评价标准,下面代码是用随机森林对鸢尾花数据集进行分类,分类结果画一个归一化的混淆矩阵。
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_digits as load_data
    from sklearn.model_selection import cross_val_predict
    import matplotlib.pyplot as plt
    import scikitplot as skplt
    X, y = load_data(return_X_y=True)
    # Create an instance of the RandomForestClassifier
    classifier = RandomForestClassifier()
    # Perform predictions
    predictions = cross_val_predict(classifier, X, y)
    plot = skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True)
    plt.show()
    

    相关文章

      网友评论

          本文标题:Scikit-plot 画图神器

          本文链接:https://www.haomeiwen.com/subject/wyydcctx.html