美文网首页算法
人工智能-决策树ID3算法

人工智能-决策树ID3算法

作者: zhangxuanchen | 来源:发表于2017-10-16 13:07 被阅读30次

    决策树是机器学习中非常经典的一类学习算法,它通过树的结构,利用树的分支来表示对样本特征的判断规则,从树的叶子节点所包含的训练样本中得到预测值。决策树如何生成决定了所能处理的数据类型和预测性能。主要的决策树算法包括ID3,C4.5, CART等。
    1,ID3

    ID3是由 Ross Quinlan在1986年提出的一种构造决策树的方法。用于处理标称型数据集,其构造过程如下:
    输入训练数据是一组带有类别标记的样本,构造的结果是一棵多叉树。树的分支节点一般表示为一个逻辑判断,如形式为a=aj的逻辑判断,其中a是属性,aj是该属性的所有取值。
    在该节点上选取能对该节点处的训练数据进行最优划分的属性。最后划分的标准是信息增益(Information Gain),即划分前后数据集的熵的差异。
    如果在该节点的父节点或者祖先中用了某个属性,则这个用过的属性就不再使用。选择好最优属性后,假设该属性有N个取值,则为该节点建立N个分支,将相应的训练数据传递到这N个分支中,递归进行,停止条件为:
    (1)该节点的所有样本属于同一类,该节点成为叶节点,存放相应的类别。
    (2)所有的属性都已被父节点或祖先使用。这种情况下,该节点成为叶节点,并以样本中元组个数最多的类别作为类别标记,同时也可以存放该结点样本的类别分布。
    ID3的特点是:(1),容易造成过度拟合。(2), 使用标称型数据,但是很难处理连续型数据。
    2, C4.5

    C4.5是对ID3的改进,其基本过程与ID3类似,改进的地方在于:
    (1)既能处理标称型数据,又能连续型数据。为了处理连续型数据,该算法在相应的节点使用一个属性的阈值,利用阈值将样本划分成两部分。
    (2)能处理缺失了一些属性的数据。该算法允许属性值缺失时被标记为?,属性值缺失的样本在计算熵增益时被忽略。
    (3)构造完成后可以剪枝。合并相邻的无法产生大量信息增益的叶节点,消除过渡匹配问题。
    3,CART

    CART称为分类决策树,classification and regression tree,既能处理分类问题,又能处理回归问题。最初由Breiman提出。与ID3不能直接处理连续型特征不同的是,CART使用二元切分,即使用一个属性阈值对样本数据进行划分。划分的标准除了使用熵增益外,还有基尼纯净度(Gini impurity)和方差缩减(variance reduction)(用于回归)。

    http://blog.csdn.net/taigw/article/details/44840771

    package com.company;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.*;
    import java.util.regex.Matcher;
    import java.util.regex.Pattern;
    
    public class ID3 {
    
        private ArrayList<String> attribute = new ArrayList<>(); // 存储属性的名称
        private ArrayList<ArrayList<String>> attributevalue = new ArrayList<>(); // 存储每个属性的取值
        private ArrayList<String[]> data = new ArrayList<>();; // 原始数据
        int decatt; // 决策变量在属性集中的索引
        public static final String patternString = "@attribute(.*)[{](.*?)[}]";
        private DecisionTreeNode rootNode;
    
        public ID3() {
            rootNode = new DecisionTreeNode();
            rootNode.nodeName = "root";
            rootNode.nodeValue = null;
        }
    
        public void getLearnData(File file){
            try {
                FileReader fr = new FileReader(file);
                BufferedReader br = new BufferedReader(fr);
                String line;
                Pattern pattern = Pattern.compile(patternString);
                while ((line = br.readLine()) != null) {
                    Matcher matcher = pattern.matcher(line);
                    if (matcher.find()) {
                        attribute.add(matcher.group(1).trim());
                        String[] values = matcher.group(2).split(",");
                        ArrayList<String> al = new ArrayList<String>(values.length);
                        for (String value : values) {
                            al.add(matcher.group(1).trim() + ":" + value.trim());
                        }
                        attributevalue.add(al);
                    } else if (line.startsWith("@data")) {
                        while ((line = br.readLine()) != null) {
                            if("".equals(line))
                                continue;
                            String[] row = line.split(",");
                            for(int i = 0 ; i < row.length ; i++){
                                row[i] = attribute.get(i) + ":" + row[i];
                            }
                            data.add(row);
                        }
                    }
                }
                br.close();
            } catch (IOException e1) {
                e1.printStackTrace();
            }
        }
    
        public void init() {
            ID3 inst = new ID3();
            inst.getLearnData(new File("/Users/zxc/Documents/app.arff"));
            inst.setDec("precision");
    
            LinkedList<Integer> ll=new LinkedList<Integer>();
            for(int i=0;i<inst.attribute.size();i++){
                if(i!=inst.decatt)
                    ll.add(i);
            }
            System.out.println("有多少个属性ll:" + ll);
            ArrayList<Integer> al=new ArrayList<Integer>();
            for(int i=0;i<inst.data.size();i++){
                al.add(i);
            }
            System.out.println("有多少条值al:" + al);
            inst.buildDT(rootNode, al, ll);
            println(rootNode, 0);
        }
    
        //读取arff文件,给attribute、attributevalue、data赋值
        private void readAttribute(String dataS) {
            Pattern pattern = Pattern.compile(patternString);
            Matcher matcher = pattern.matcher(dataS);
            while (matcher.find()) {
                attribute.add(matcher.group(1).trim());
                String[] values = matcher.group(2).split(",");
                ArrayList<String> al = new ArrayList<String>(values.length);
                for (String value : values) {
                    al.add(value.trim());
                }
                attributevalue.add(al);
            }
        }
    
        private void readData(String dataS){
            String[] line = dataS.split("\n");
            for (int i = 0 ; i < line.length ; i++) {
                String[] row = line[i].split(",");
                data.add(row);
            }
        }
    
        //设置决策变量
        private void setDec(int n) {
            if (n < 0 || n >= attribute.size()) {
                System.err.println("决策变量指定错误。");
                System.exit(2);
            }
            decatt = n;
        }
        private void setDec(String name) {
            int n = attribute.indexOf(name);
            setDec(n);
        }
    
        //给一个样本(数组中是各种情况的计数),计算它的熵
        private double getEntropy(int[] arr) {
            double entropy = 0.0;
            int sum = 0;
            for (int i = 0; i < arr.length; i++) {
                entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
                sum += arr[i];
            }
            entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
            entropy /= sum;
            return entropy;
        }
    
        // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵
        private double calNodeEntropy(ArrayList<Integer> subset, int index) {
            int sum = subset.size();
            double entropy = 0.0;
            int[][] info = new int[attributevalue.get(index).size()][];
            for (int i = 0; i < info.length; i++)
                info[i] = new int[attributevalue.get(decatt).size()];
            int[] count = new int[attributevalue.get(index).size()];
            for (int i = 0; i < sum; i++) {
                int n = subset.get(i);
                String nodevalue = data.get(n)[index];
                int nodeind = attributevalue.get(index).indexOf(nodevalue);
                count[nodeind]++;
    
                String decvalue = data.get(n)[decatt];
                int decind = attributevalue.get(decatt).indexOf(decvalue);
                info[nodeind][decind]++;
    
            }
            for (int i = 0; i < info.length; i++) {
                if(count[i] != 0){
                    entropy += getEntropy(info[i]) * count[i] / sum;
                }
            }
    
            //C45 -start
            //H(OUTLOOK) = - 5/14 * log2(5/14) - 5/14 * log2(5/14) - 4/14 * log2(4/14) = 1.577406282852345
            //Info(OUTLOOK) = 5/14 * [- 2/5 * log2(2/5) – 3/5 * log2(3/5)] + 4/14 * [ - 4/4 * log2(4/4) - 0/4 * log2(0/4)] + 5/14 * [ - 3/5 * log2(3/5) – 2/5 * log2(2/5)] = 0.694
            //IGR(OUTLOOK) = Info(OUTLOOK) / H(OUTLOOK) = 0.246/1.577406282852345 = 0.15595221261270145
            //c45算法改进 - start
    //      double x = 0;
    //      for (int i = 0; i < info.length; i++) {
    //          x -= (count[i] / sum) * Math.log(count[i] / sum) / Math.log(2);
    //      }
    //      entropy = entropy / x;
            //C45 - end
    
            return entropy;
        }
    
        private boolean infoPure(ArrayList<Integer> subset) {
            String value = data.get(subset.get(0))[decatt];
            for (int i = 1; i < subset.size(); i++) {
                String next=data.get(subset.get(i))[decatt];
                //equals表示对象内容相同,==表示两个对象指向的是同一片内存
                if (!value.equals(next))
                    return false;
            }
            return true;
        }
    
        // 找到子集中决策属性最普遍的值
        public String getGeneralAttr(ArrayList<Integer> subset)
        {
            String result = null;
            int valueCount = attributevalue.get(decatt).size();
            String[] values = new String[valueCount];
            for (int i = 0; i < values.length; i++)
                values[i] = attributevalue.get(decatt).get(i);
            int[] counts = new int[valueCount];
            for (int i = 0; i < counts.length; i++)
                counts[i] = 0;
            for (int i = 0; i < subset.size(); i++)
            {
                for (int j = 0; j < values.length; j++)
                {
                    if (data.get(subset.get(i))[decatt].equals(values[j]))
                        counts[j]++;
                }
            }
            int max = -1;
            int k = 0;
            for (int i = 0; i < counts.length; i++)
            {
                if (max < counts[i])
                {
                    k = i;
                    max = counts[i];
                }
            }
            result = attributevalue.get(decatt).get(k);
            return result;
        }
    
    
        // 构建决策树
        private void buildDT(DecisionTreeNode node, ArrayList<Integer> subset,
                             LinkedList<Integer> selatt) {
    
            //准确落实叶子节点
            if (infoPure(subset)) {
    //            System.out.println(data.get(subset.get(0))[decatt]);
                node.nodeKey = data.get(subset.get(0))[decatt];
                return;
            }
    
            //概率选择叶子节点
            if(selatt == null || selatt.size() == 0){
    //          System.out.println(probability(subset));
                node.nodeKey = getGeneralAttr(subset);
                return;
            }
    
            //找出熵衰减最快的节点
            int minIndex = -1;
            double minEntropy = Double.MAX_VALUE;
            for (int i = 0; i < selatt.size(); i++) {
                if (i == decatt)
                    continue;
                //找出节点信息熵最小的,熵越小信息约有序
                double entropy = calNodeEntropy(subset, selatt.get(i));
                if (entropy < minEntropy) {
                    minIndex = selatt.get(i);
                    minEntropy = entropy;
                }
            }
    
            //根据属性位置,找到属性名称
            String nodeName = attribute.get(minIndex);
            //获得这个属性的属性值
            ArrayList<String> attvalues = attributevalue.get(minIndex);
            //删除这一列的数据纪录,就是删除一个节点,这个节点已经有值了
            selatt.remove(new Integer(minIndex));
    
            for (String val : attvalues) {
                System.out.println("nodeName:"+nodeName +"  value:" + val);
                DecisionTreeNode newNode = new DecisionTreeNode();
                newNode.nodeName = nodeName;
                newNode.nodeValue = val;
                node.map.put(nodeName + val, newNode);
    
                //缩小数据查找范围
                ArrayList<Integer> al = new ArrayList<>();
                //一行一行获取数据
                for (int i = 0; i < subset.size(); i++) {
                    //数据分类梳理,范围进行缩小
                    if (data.get(subset.get(i))[minIndex].equals(val)) {
                        al.add(subset.get(i));
                    }
                }
                //删除这一列的数据纪录,传入下一条目继续递归获取数据
                LinkedList<Integer> newLinkedList = new LinkedList<>();
                newLinkedList.addAll(selatt);
                if(minEntropy != 0 && al.size() > 0){
                    buildDT(newNode, al, newLinkedList);
                }else {
                    buildDT(newNode, subset, null);
                }
            }
        }
    
    //    String learnAttribute ="" +
    //            "@attribute outlook {sunny, overcast, rainy}\n" +
    //            "@attribute temperature {hot, mild, cool}\n" +
    //            "@attribute humidity {high, normal}\n" +
    //            "@attribute windy {TRUE, FALSE}\n" +
    //            "@attribute play {yes, no}";
    //
    //    String learnData =
    //                    "sunny,hot,high,FALSE,no\n" +
    //                    "sunny,hot,high,TRUE,no\n" +
    //                    "sunny,mild,high,FALSE,no\n" +
    //                    "sunny,cool,normal,FALSE,yes\n" +
    //                    "sunny,mild,normal,TRUE,yes\n" +
    //                    "overcast,hot,high,FALSE,yes\n" +
    //                    "overcast,cool,normal,TRUE,yes\n" +
    //                    "overcast,mild,high,TRUE,yes\n" +
    //                    "overcast,hot,normal,FALSE,yes\n" +
    //                    "rainy,mild,high,FALSE,yes\n" +
    //                    "rainy,mild,normal,FALSE,yes\n" +
    //                    "rainy,mild,high,TRUE,no\n" +
    //                    "rainy,cool,normal,FALSE,yes\n" +
    //                    "rainy,cool,normal,TRUE,no";
    
    
        private void println(DecisionTreeNode node, int tree){
            //打印这一层级的数据
            for(int i = 0 ; i < tree ; i++){
                System.out.print(" ");
            }
            System.out.println("<"+node.nodeName + " value=" + node.nodeValue + ">");
            if(node.map.size() > 0){
               Set<String> stringSet = node.map.keySet();
               for(String s : stringSet){
                   println(node.map.get(s), ++tree);
               }
            }else {
                for(int i = 0 ; i < tree ; i++){
                    System.out.print("  ");
                }
    
                System.out.println(node.nodeKey);
            }
    
            for(int i = 0 ; i < tree ; i++){
                System.out.print("  ");
            }
            System.out.println("</" + node.nodeName + ">");
        }
    
        public String result(String s){
            String[] findS = s.split(",");
            List<String> list = new ArrayList<>();
            for(int i = 0 ; i < findS.length ; i++){
                list.add(attribute.get(i) + ":" + findS[i]);
            }
            return result(rootNode, list);
        }
    
        //节点查找
        private String result(DecisionTreeNode rootNode, List<String> findS){
            if(rootNode.map.size() > 0){
                Set<String> keySet = rootNode.map.keySet();
                for(String key : keySet){
                    System.out.println("key:" + key);
                    if(findS.contains(key)){
                       System.out.println("找到:" + key);
                       findS.remove(key);
                       return result(rootNode.map.get(key), findS);
                    }
                }
                return null;
            }else {
                System.out.println("result:" + rootNode.nodeKey);
               return rootNode.nodeKey;
            }
        }
    }
    
    //生成树
    class DecisionTreeNode{
        String nodeName;
        String nodeValue;
        String nodeKey;
        Map<String, DecisionTreeNode> map = new HashMap<>();
    }
    

    相关文章

      网友评论

        本文标题:人工智能-决策树ID3算法

        本文链接:https://www.haomeiwen.com/subject/pbykuxtx.html