美文网首页
CART回归树python实现

CART回归树python实现

作者: yxwithu | 来源:发表于2017-11-30 10:08 被阅读0次
    import numpy as np
    
    def reg_leaf(dataSet):
        """Produce a leaf value: the mean of the target (last) column of ``dataSet``."""
        target_column = dataSet[:, -1]
        leaf_value = np.mean(target_column)
        return leaf_value
    
    def reg_err(dataSet):
        """Return the total squared error of the target column around its mean.

        The value is ``np.var`` of the last column multiplied by the number of
        rows, i.e. the sum of squared deviations from the column mean.

        Args:
            dataSet: 2-D NumPy array or matrix whose last column is the target.

        Returns:
            The summed squared deviation, or ``0.0`` when ``dataSet`` holds no
            elements.
        """
        # An empty split contributes no error; without this guard np.var would
        # produce NaN (or the column index would fail on a (1, 0) matrix).
        if np.size(dataSet) == 0:
            return 0.0
        # Use np.shape explicitly: a bare ``shape`` is not in scope when numpy
        # is imported as ``np``.
        return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]
    
    def bin_split_data(dataSet, feat_index, split_value):
        """Split the rows of ``dataSet`` in two by thresholding one feature.

        Args:
            dataSet: 2-D array-like (ndarray, matrix or nested list) of samples.
            feat_index: Column index of the feature to compare.
            split_value: Threshold; rows with ``feature <= split_value`` go to
                the first result, all other rows to the second.

        Returns:
            A pair ``(left, right)`` of ``np.matrix`` objects. Each keeps the
            column count of ``dataSet``; a side that receives no rows has
            shape ``(0, n_columns)``.
        """
        # Work on a plain 2-D ndarray so row selection behaves the same for
        # ndarray and np.matrix inputs.
        rows = np.asarray(dataSet)
        goes_left = rows[:, feat_index] <= split_value
        # A boolean mask keeps the column count even when one side is empty,
        # whereas building a matrix from an empty list gives shape (1, 0).
        # np.asmatrix keeps the matrix return type (np.mat is gone in NumPy 2).
        return np.asmatrix(rows[goes_left]), np.asmatrix(rows[~goes_left])
    
    def choose_best_split(dataSet, leaf_type=reg_leaf, err_type=reg_err, ops=(1, 4)):
        """Pick the feature index and threshold that best split ``dataSet``.

        Every distinct value of every feature column (all columns except the
        last) is tried as a ``<=`` threshold. The candidate whose two children
        have the smallest summed ``err_type`` wins.

        Args:
            dataSet: 2-D NumPy array or matrix; the last column is the target.
            leaf_type: Callable that turns a data set into a leaf value.
            err_type: Callable that returns the error of a data set.
            ops: Pair ``(tolS, tolN)``. ``tolS`` is the minimum error reduction
                required to accept a split; ``tolN`` is the minimum number of
                samples each child must contain.

        Returns:
            ``(best_index, best_value)`` when a worthwhile split exists,
            otherwise ``(None, leaf_type(dataSet))``.
        """
        tolS, tolN = ops
        m, n = np.shape(dataSet)

        # Flatten the target column so the check works for ndarray and
        # np.matrix alike (matrix elements are unhashable, so set() fails).
        targets = np.asarray(dataSet[:, -1]).ravel()
        # Too few samples, or all targets identical: nothing to gain.
        if m <= tolN or np.unique(targets).size == 1:
            return None, leaf_type(dataSet)

        ori_err = err_type(dataSet)

        lowest_err = np.inf
        best_index = None
        best_value = None

        for index in range(n - 1):
            column = np.asarray(dataSet[:, index]).ravel()
            # np.unique yields sorted candidates, so ties resolve the same way
            # on every run (set iteration order is arbitrary).
            for value in np.unique(column):
                # Count child sizes from the column itself so undersized
                # splits are skipped before any data is copied.
                left_count = int((column <= value).sum())
                if left_count < tolN or m - left_count < tolN:
                    continue
                mat0, mat1 = bin_split_data(dataSet, index, value)
                # Error of a split is the sum of both children's errors.
                new_err = err_type(mat0) + err_type(mat1)
                if new_err < lowest_err:
                    lowest_err = new_err
                    best_index = index
                    best_value = value

        # No admissible split, or the improvement is below the threshold.
        if best_index is None or (ori_err - lowest_err) < tolS:
            return None, leaf_type(dataSet)

        return best_index, best_value
    
    def create_tree(dataSet, feat_names):
        """Recursively build a decision tree as nested dictionaries.

        The result has the form ``{feature_name: {feature_value: subtree}}``,
        where a subtree is either another such dictionary or a leaf label.

        NOTE(review): this relies on ``get_majority_cnt_label``,
        ``choose_best_feat_ent_ratio`` and ``splitDataSet``, none of which are
        defined in the visible part of the file. It also does not use the
        regression helpers above; confirm it is the intended builder.

        Args:
            dataSet: Non-empty sequence of rows; the last item of each row is
                the label.
            feat_names: Sequence of feature names indexed like the feature
                columns. It is not modified.

        Returns:
            A leaf label, or a nested dict describing the subtree.
        """
        y_train = [row[-1] for row in dataSet]
        if len(set(y_train)) == 1:  # all labels identical: return that label
            return y_train[0]
        if len(dataSet[0]) == 1:  # only the label column is left to look at
            # presumably returns the most frequent label; helper not shown
            return get_majority_cnt_label(y_train)

        best_feat = choose_best_feat_ent_ratio(dataSet)  # index of the chosen feature
        best_feat_name = feat_names[best_feat]

        # Drop the chosen feature from a copy, so the list passed in by the
        # caller is left untouched (the original deleted from it in place).
        remaining_names = list(feat_names)
        del remaining_names[best_feat]

        myTree = {best_feat_name: {}}
        feat_data_dict = splitDataSet(col_id=best_feat, dataSet=dataSet)
        for feat_value, data in feat_data_dict.items():
            # Recursive calls never mutate the names, so sharing is safe.
            # Assumes splitDataSet only yields non-empty subsets (TODO confirm).
            myTree[best_feat_name][feat_value] = create_tree(data, remaining_names)
        return myTree
    

    相关文章

      网友评论

          本文标题:CART回归树python实现

          本文链接:https://www.haomeiwen.com/subject/hsombxtx.html