scikit-learn决策树算法 鸢尾花数据集

作者: fred_33c7 | 来源:发表于2019-04-30 12:56 被阅读4次

    1.scikit-learn决策树算法类库

    scikit-learn决策树算法类库内部实现是使用了调优过的CART树算法,既可以做分类,又可以做回归。分类决策树的类对应的是DecisionTreeClassifier,而回归决策树的类对应的是DecisionTreeRegressor。

    2. 程序详解

    1. 首先导入需要的库

    import numpy as np
    import matplotlib.pyplot as plt
    
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    

    2.程序和备注

    import numpy as np
    import matplotlib.pyplot as plt
    
    from sklearn import datasets
    from sklearn.tree import DecisionTreeClassifier
    
    
    # 仍然使用自带的iris数据
    iris = datasets.load_iris()
    X = iris.data[:, [2, 3]]
    y = iris.target
    
    # 训练模型,限制树的最大深度4
    clf = DecisionTreeClassifier(max_depth=4)
    #拟合模型
    clf.fit(X, y)
    
    
    # 画图
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                         np.arange(y_min, y_max, 0.1))
    
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # 绘制等高线plt.contourf() 区域颜色填充 alpha 指透明度
    plt.contourf(xx, yy, Z, alpha=0.4)
    plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
    plt.show()
    

    2.结果图

    结果图

    相关文章

      网友评论

        本文标题:scikit-learn决策树算法 鸢尾花数据集

        本文链接:https://www.haomeiwen.com/subject/bdemnqtx.html