美文网首页算法
人工智能-决策树ID3算法

人工智能-决策树ID3算法

作者: zhangxuanchen | 来源:发表于2017-10-16 13:07 被阅读30次

决策树是机器学习中非常经典的一类学习算法,它通过树的结构,利用树的分支来表示对样本特征的判断规则,从树的叶子节点所包含的训练样本中得到预测值。决策树如何生成决定了所能处理的数据类型和预测性能。主要的决策树算法包括ID3,C4.5, CART等。
1,ID3

ID3是由 Ross Quinlan在1986年提出的一种构造决策树的方法。用于处理标称型数据集,其构造过程如下:
输入训练数据是一组带有类别标记的样本,构造的结果是一棵多叉树。树的分支节点一般表示为一个逻辑判断,如形式为a=aj的逻辑判断,其中a是属性,aj是该属性的所有取值。
在该节点上选取能对该节点处的训练数据进行最优划分的属性。最后划分的标准是信息增益(Information Gain),即划分前后数据集的熵的差异。
如果在该节点的父节点或者祖先中用了某个属性,则这个用过的属性就不再使用。选择好最优属性后,假设该属性有N个取值,则为该节点建立N个分支,将相应的训练数据传递到这N个分支中,递归进行,停止条件为:
(1)该节点的所有样本属于同一类,该节点成为叶节点,存放相应的类别。
(2)所有的属性都已被父节点或祖先使用。这种情况下,该节点成为叶节点,并以样本中元组个数最多的类别作为类别标记,同时也可以存放该结点样本的类别分布。
ID3的特点是:(1),容易造成过度拟合。(2), 使用标称型数据,但是很难处理连续型数据。
2, C4.5

C4.5是对ID3的改进,其基本过程与ID3类似,改进的地方在于:
(1)既能处理标称型数据,又能连续型数据。为了处理连续型数据,该算法在相应的节点使用一个属性的阈值,利用阈值将样本划分成两部分。
(2)能处理缺失了一些属性的数据。该算法允许属性值缺失时被标记为?,属性值缺失的样本在计算熵增益时被忽略。
(3)构造完成后可以剪枝。合并相邻的无法产生大量信息增益的叶节点,消除过渡匹配问题。
3,CART

CART称为分类决策树,classification and regression tree,既能处理分类问题,又能处理回归问题。最初由Breiman提出。与ID3不能直接处理连续型特征不同的是,CART使用二元切分,即使用一个属性阈值对样本数据进行划分。划分的标准除了使用熵增益外,还有基尼纯净度(Gini impurity)和方差缩减(variance reduction)(用于回归)。

http://blog.csdn.net/taigw/article/details/44840771

package com.company;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ID3 {

    private ArrayList<String> attribute = new ArrayList<>(); // 存储属性的名称
    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<>(); // 存储每个属性的取值
    private ArrayList<String[]> data = new ArrayList<>();; // 原始数据
    int decatt; // 决策变量在属性集中的索引
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
    private DecisionTreeNode rootNode;

    public ID3() {
        rootNode = new DecisionTreeNode();
        rootNode.nodeName = "root";
        rootNode.nodeValue = null;
    }

    public void getLearnData(File file){
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(matcher.group(1).trim() + ":" + value.trim());
                    }
                    attributevalue.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if("".equals(line))
                            continue;
                        String[] row = line.split(",");
                        for(int i = 0 ; i < row.length ; i++){
                            row[i] = attribute.get(i) + ":" + row[i];
                        }
                        data.add(row);
                    }
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

    public void init() {
        ID3 inst = new ID3();
        inst.getLearnData(new File("/Users/zxc/Documents/app.arff"));
        inst.setDec("precision");

        LinkedList<Integer> ll=new LinkedList<Integer>();
        for(int i=0;i<inst.attribute.size();i++){
            if(i!=inst.decatt)
                ll.add(i);
        }
        System.out.println("有多少个属性ll:" + ll);
        ArrayList<Integer> al=new ArrayList<Integer>();
        for(int i=0;i<inst.data.size();i++){
            al.add(i);
        }
        System.out.println("有多少条值al:" + al);
        inst.buildDT(rootNode, al, ll);
        println(rootNode, 0);
    }

    //读取arff文件,给attribute、attributevalue、data赋值
    private void readAttribute(String dataS) {
        Pattern pattern = Pattern.compile(patternString);
        Matcher matcher = pattern.matcher(dataS);
        while (matcher.find()) {
            attribute.add(matcher.group(1).trim());
            String[] values = matcher.group(2).split(",");
            ArrayList<String> al = new ArrayList<String>(values.length);
            for (String value : values) {
                al.add(value.trim());
            }
            attributevalue.add(al);
        }
    }

    private void readData(String dataS){
        String[] line = dataS.split("\n");
        for (int i = 0 ; i < line.length ; i++) {
            String[] row = line[i].split(",");
            data.add(row);
        }
    }

    //设置决策变量
    private void setDec(int n) {
        if (n < 0 || n >= attribute.size()) {
            System.err.println("决策变量指定错误。");
            System.exit(2);
        }
        decatt = n;
    }
    private void setDec(String name) {
        int n = attribute.indexOf(name);
        setDec(n);
    }

    //给一个样本(数组中是各种情况的计数),计算它的熵
    private double getEntropy(int[] arr) {
        double entropy = 0.0;
        int sum = 0;
        for (int i = 0; i < arr.length; i++) {
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
            sum += arr[i];
        }
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
        entropy /= sum;
        return entropy;
    }

    // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵
    private double calNodeEntropy(ArrayList<Integer> subset, int index) {
        int sum = subset.size();
        double entropy = 0.0;
        int[][] info = new int[attributevalue.get(index).size()][];
        for (int i = 0; i < info.length; i++)
            info[i] = new int[attributevalue.get(decatt).size()];
        int[] count = new int[attributevalue.get(index).size()];
        for (int i = 0; i < sum; i++) {
            int n = subset.get(i);
            String nodevalue = data.get(n)[index];
            int nodeind = attributevalue.get(index).indexOf(nodevalue);
            count[nodeind]++;

            String decvalue = data.get(n)[decatt];
            int decind = attributevalue.get(decatt).indexOf(decvalue);
            info[nodeind][decind]++;

        }
        for (int i = 0; i < info.length; i++) {
            if(count[i] != 0){
                entropy += getEntropy(info[i]) * count[i] / sum;
            }
        }

        //C45 -start
        //H(OUTLOOK) = - 5/14 * log2(5/14) - 5/14 * log2(5/14) - 4/14 * log2(4/14) = 1.577406282852345
        //Info(OUTLOOK) = 5/14 * [- 2/5 * log2(2/5) – 3/5 * log2(3/5)] + 4/14 * [ - 4/4 * log2(4/4) - 0/4 * log2(0/4)] + 5/14 * [ - 3/5 * log2(3/5) – 2/5 * log2(2/5)] = 0.694
        //IGR(OUTLOOK) = Info(OUTLOOK) / H(OUTLOOK) = 0.246/1.577406282852345 = 0.15595221261270145
        //c45算法改进 - start
//      double x = 0;
//      for (int i = 0; i < info.length; i++) {
//          x -= (count[i] / sum) * Math.log(count[i] / sum) / Math.log(2);
//      }
//      entropy = entropy / x;
        //C45 - end

        return entropy;
    }

    private boolean infoPure(ArrayList<Integer> subset) {
        String value = data.get(subset.get(0))[decatt];
        for (int i = 1; i < subset.size(); i++) {
            String next=data.get(subset.get(i))[decatt];
            //equals表示对象内容相同,==表示两个对象指向的是同一片内存
            if (!value.equals(next))
                return false;
        }
        return true;
    }

    // 找到子集中决策属性最普遍的值
    public String getGeneralAttr(ArrayList<Integer> subset)
    {
        String result = null;
        int valueCount = attributevalue.get(decatt).size();
        String[] values = new String[valueCount];
        for (int i = 0; i < values.length; i++)
            values[i] = attributevalue.get(decatt).get(i);
        int[] counts = new int[valueCount];
        for (int i = 0; i < counts.length; i++)
            counts[i] = 0;
        for (int i = 0; i < subset.size(); i++)
        {
            for (int j = 0; j < values.length; j++)
            {
                if (data.get(subset.get(i))[decatt].equals(values[j]))
                    counts[j]++;
            }
        }
        int max = -1;
        int k = 0;
        for (int i = 0; i < counts.length; i++)
        {
            if (max < counts[i])
            {
                k = i;
                max = counts[i];
            }
        }
        result = attributevalue.get(decatt).get(k);
        return result;
    }


    // 构建决策树
    private void buildDT(DecisionTreeNode node, ArrayList<Integer> subset,
                         LinkedList<Integer> selatt) {

        //准确落实叶子节点
        if (infoPure(subset)) {
//            System.out.println(data.get(subset.get(0))[decatt]);
            node.nodeKey = data.get(subset.get(0))[decatt];
            return;
        }

        //概率选择叶子节点
        if(selatt == null || selatt.size() == 0){
//          System.out.println(probability(subset));
            node.nodeKey = getGeneralAttr(subset);
            return;
        }

        //找出熵衰减最快的节点
        int minIndex = -1;
        double minEntropy = Double.MAX_VALUE;
        for (int i = 0; i < selatt.size(); i++) {
            if (i == decatt)
                continue;
            //找出节点信息熵最小的,熵越小信息约有序
            double entropy = calNodeEntropy(subset, selatt.get(i));
            if (entropy < minEntropy) {
                minIndex = selatt.get(i);
                minEntropy = entropy;
            }
        }

        //根据属性位置,找到属性名称
        String nodeName = attribute.get(minIndex);
        //获得这个属性的属性值
        ArrayList<String> attvalues = attributevalue.get(minIndex);
        //删除这一列的数据纪录,就是删除一个节点,这个节点已经有值了
        selatt.remove(new Integer(minIndex));

        for (String val : attvalues) {
            System.out.println("nodeName:"+nodeName +"  value:" + val);
            DecisionTreeNode newNode = new DecisionTreeNode();
            newNode.nodeName = nodeName;
            newNode.nodeValue = val;
            node.map.put(nodeName + val, newNode);

            //缩小数据查找范围
            ArrayList<Integer> al = new ArrayList<>();
            //一行一行获取数据
            for (int i = 0; i < subset.size(); i++) {
                //数据分类梳理,范围进行缩小
                if (data.get(subset.get(i))[minIndex].equals(val)) {
                    al.add(subset.get(i));
                }
            }
            //删除这一列的数据纪录,传入下一条目继续递归获取数据
            LinkedList<Integer> newLinkedList = new LinkedList<>();
            newLinkedList.addAll(selatt);
            if(minEntropy != 0 && al.size() > 0){
                buildDT(newNode, al, newLinkedList);
            }else {
                buildDT(newNode, subset, null);
            }
        }
    }

//    String learnAttribute ="" +
//            "@attribute outlook {sunny, overcast, rainy}\n" +
//            "@attribute temperature {hot, mild, cool}\n" +
//            "@attribute humidity {high, normal}\n" +
//            "@attribute windy {TRUE, FALSE}\n" +
//            "@attribute play {yes, no}";
//
//    String learnData =
//                    "sunny,hot,high,FALSE,no\n" +
//                    "sunny,hot,high,TRUE,no\n" +
//                    "sunny,mild,high,FALSE,no\n" +
//                    "sunny,cool,normal,FALSE,yes\n" +
//                    "sunny,mild,normal,TRUE,yes\n" +
//                    "overcast,hot,high,FALSE,yes\n" +
//                    "overcast,cool,normal,TRUE,yes\n" +
//                    "overcast,mild,high,TRUE,yes\n" +
//                    "overcast,hot,normal,FALSE,yes\n" +
//                    "rainy,mild,high,FALSE,yes\n" +
//                    "rainy,mild,normal,FALSE,yes\n" +
//                    "rainy,mild,high,TRUE,no\n" +
//                    "rainy,cool,normal,FALSE,yes\n" +
//                    "rainy,cool,normal,TRUE,no";


    private void println(DecisionTreeNode node, int tree){
        //打印这一层级的数据
        for(int i = 0 ; i < tree ; i++){
            System.out.print(" ");
        }
        System.out.println("<"+node.nodeName + " value=" + node.nodeValue + ">");
        if(node.map.size() > 0){
           Set<String> stringSet = node.map.keySet();
           for(String s : stringSet){
               println(node.map.get(s), ++tree);
           }
        }else {
            for(int i = 0 ; i < tree ; i++){
                System.out.print("  ");
            }

            System.out.println(node.nodeKey);
        }

        for(int i = 0 ; i < tree ; i++){
            System.out.print("  ");
        }
        System.out.println("</" + node.nodeName + ">");
    }

    public String result(String s){
        String[] findS = s.split(",");
        List<String> list = new ArrayList<>();
        for(int i = 0 ; i < findS.length ; i++){
            list.add(attribute.get(i) + ":" + findS[i]);
        }
        return result(rootNode, list);
    }

    //节点查找
    private String result(DecisionTreeNode rootNode, List<String> findS){
        if(rootNode.map.size() > 0){
            Set<String> keySet = rootNode.map.keySet();
            for(String key : keySet){
                System.out.println("key:" + key);
                if(findS.contains(key)){
                   System.out.println("找到:" + key);
                   findS.remove(key);
                   return result(rootNode.map.get(key), findS);
                }
            }
            return null;
        }else {
            System.out.println("result:" + rootNode.nodeKey);
           return rootNode.nodeKey;
        }
    }
}

//生成树
class DecisionTreeNode{
    String nodeName;
    String nodeValue;
    String nodeKey;
    Map<String, DecisionTreeNode> map = new HashMap<>();
}

相关文章

网友评论

    本文标题:人工智能-决策树ID3算法

    本文链接:https://www.haomeiwen.com/subject/pbykuxtx.html