Data Analysis 03 - Decision Trees


Author: 数据社 | Published: 2019-10-14 16:56

Tags: data analysis


Classification with a CART decision tree

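    # encoding=utf-8
    import os

    import graphviz
    from sklearn import tree
    from sklearn.datasets import load_iris
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier

    # Make the Graphviz binaries visible on Windows. This is the author's local
    # Anaconda path; adjust it (or drop this line) for your own installation.
    # A raw string avoids accidental backslash escapes in the path.
    os.environ["PATH"] += os.pathsep + r'C:\Users\qincf\AppData\Local\Continuum\anaconda3\Library\bin\graphviz'

    # Prepare the dataset
    iris = load_iris()
    # Feature matrix and class labels
    features = iris.data
    labels = iris.target
    # Randomly hold out 33% of the data as the test set; the rest is the training set
    train_features, test_features, train_labels, test_labels = train_test_split(
        features, labels, test_size=0.33, random_state=0)
    # Create a CART classification tree (split quality measured by Gini impurity)
    clf = DecisionTreeClassifier(criterion='gini')
    # Fit (build) the tree on the training data
    clf = clf.fit(train_features, train_labels)
    # Predict the classes of the test samples
    test_predict = clf.predict(test_features)
    print(test_predict)
    # Compare the predictions with the true test labels
    score = accuracy_score(test_labels, test_predict)
    print("CART classification tree accuracy: %.4f" % score)

    # Export the fitted tree as Graphviz DOT source and render it
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=iris.feature_names,
                                    class_names=iris.target_names)
    graph = graphviz.Source(dot_data)
    # In Jupyter, evaluating `graph` on its own displays it inline; in a script,
    # view() renders it to a file, opens it, and returns the file path
    print(graph.view())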
    

The resulting classification tree:

[Figure: the CART classification tree for the iris dataset]
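The classifier above passes `criterion='gini'`, so each split is chosen to lower the Gini impurity, 1 - sum of p_k^2 over all classes k, where p_k is the share of samples of class k in a node. To make that concrete, here is a minimal pure-Python sketch of the impurity and of a CART-style search for the best binary split (`x <= t` versus `x > t`) on a single numeric feature. The function names `gini` and `best_split` and the toy data are hypothetical; this illustrates the idea and is not scikit-learn's implementation, which searches every feature and handles many more details.

```python
from collections import Counter
from typing import Optional, Sequence, Tuple


def gini(labels: Sequence[int]) -> float:
    """Gini impurity 1 - sum(p_k^2), where p_k is the share of class k in `labels`."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(x: Sequence[float], y: Sequence[int]) -> Tuple[Optional[float], float]:
    """Search one numeric feature for the threshold t that minimizes the
    size-weighted Gini impurity of the children {x <= t} and {x > t}.

    Returns (threshold, weighted_impurity), or (None, inf) if all x are equal.
    """
    pairs = sorted(zip(x, y))
    n = len(pairs)
    best_t: Optional[float] = None
    best_score = float("inf")
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # no threshold can separate equal feature values
        t = (lo + hi) / 2  # candidate threshold: midpoint between neighbours
        left = [label for _, label in pairs[:i]]
        right = [label for _, label in pairs[i:]]
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best_score:  # strict "<" keeps the first best threshold
            best_t, best_score = t, score
    return best_t, best_score


# Hypothetical toy data: two classes cleanly separated between 2.0 and 5.0
print(gini([0, 0, 1, 1]))  # 0.5
print(best_split([1.0, 1.5, 2.0, 5.0, 5.5, 6.0], [0, 0, 0, 1, 1, 1]))  # (3.5, 0.0)
```

Growing a whole tree amounts to running this search on every feature, keeping the best split, and recursing into the two children until a node is pure or a stopping rule such as `max_depth` applies.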

Prediction with a CART regression tree

    # encoding=utf-8
    # Note: load_boston was removed in scikit-learn 1.2, so this block needs an older version
    from sklearn.datasets import load_boston
    from sklearn.metrics import mean_absolute_error, mean_squared_error
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeRegressor

    # Prepare the dataset
    boston = load_boston()
    # Explore the data: list the feature names
    print(boston.feature_names)
    # Feature matrix and house prices
    features = boston.data
    prices = boston.target
    # Randomly hold out 33% of the data as the test set; the rest is the training set
    # (random_state fixes the split so results are reproducible)
    train_features, test_features, train_price, test_price = train_test_split(
        features, prices, test_size=0.33, random_state=0)
    # Create a CART regression tree
    dtr = DecisionTreeRegressor()
    # Fit (build) the regression tree
    dtr.fit(train_features, train_price)
    # Predict the house prices in the test set
    predict_price = dtr.predict(test_features)
    print(predict_price)
    # Evaluate the predictions on the test set
    print('Regression tree mean squared error:', mean_squared_error(test_price, predict_price))
    print('Regression tree mean absolute error:', mean_absolute_error(test_price, predict_price))
    

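`DecisionTreeRegressor()` is used above with its default criterion, which scores splits by squared error (called `squared_error` in recent scikit-learn releases, `mse` in older ones), and each leaf predicts the mean target of its training samples. The sketch below illustrates that idea in plain Python under simplifying assumptions (one numeric feature, a single split, no stopping rules) and also spells out the two metrics printed above. All function names and the toy numbers are hypothetical; this is not scikit-learn's implementation.

```python
from typing import Optional, Sequence, Tuple


def sse(values: Sequence[float]) -> float:
    """Sum of squared deviations from the mean (0.0 for an empty node)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def best_regression_split(
    x: Sequence[float], y: Sequence[float]
) -> Tuple[Optional[float], float]:
    """Find the threshold t on one feature that minimizes
    sse(left) + sse(right) for the children {x <= t} and {x > t}.

    Returns (threshold, total_sse), or (None, inf) if all x are equal.
    """
    pairs = sorted(zip(x, y))
    best_t: Optional[float] = None
    best_score = float("inf")
    for i in range(1, len(pairs)):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # equal feature values cannot be separated
        left = [target for _, target in pairs[:i]]
        right = [target for _, target in pairs[i:]]
        score = sse(left) + sse(right)
        if score < best_score:
            best_t, best_score = (lo + hi) / 2, score  # midpoint threshold
    return best_t, best_score


def leaf_value(targets: Sequence[float]) -> float:
    """A squared-error regression leaf predicts the mean of its training targets."""
    return sum(targets) / len(targets)


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error, the quantity mean_squared_error reports."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error, the quantity mean_absolute_error reports."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical toy data: the target jumps between x = 3 and x = 10
x = [1, 2, 3, 10, 11, 12]
y = [5.0, 6.0, 7.0, 20.0, 21.0, 22.0]
print(best_regression_split(x, y))  # (6.5, 4.0)
print(leaf_value([5.0, 6.0, 7.0]), leaf_value([20.0, 21.0, 22.0]))  # 6.0 21.0
print(mse([3.0, 5.0], [1.0, 5.0]), mae([3.0, 5.0], [1.0, 5.0]))  # 2.0 1.0
```

Because the tree is trained to minimize squared error, the MSE is the metric most directly tied to its objective; the MAE is a complementary view that is less sensitive to a few large errors.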

Original link: https://www.haomeiwen.com/subject/vunrmctx.html