Decision trees are a classic family of machine learning algorithms. A tree's branches encode decision rules over the sample features, and a prediction is derived from the training samples contained in the leaf node that a sample reaches. How the tree is grown determines both the kinds of data it can handle and its predictive performance. The main decision tree algorithms include ID3, C4.5, and CART.
1. ID3
ID3, proposed by Ross Quinlan in 1986, is a method for constructing decision trees over nominal (categorical) data. Its construction process is as follows:
The input training data is a set of samples with class labels, and the result is a multi-way tree. Each internal node represents a logical test of the form a = aj, where a is an attribute and aj is one of its possible values.
At each node, the algorithm selects the attribute that best partitions the training data reaching that node. The splitting criterion is information gain, i.e., the reduction in the entropy of the data set before versus after the split.
Any attribute already used at the node's parent or an ancestor is not reused. Once the best attribute is chosen, suppose it has N values; the node is given N branches and the corresponding training data is passed down each branch. This proceeds recursively until one of the stopping conditions holds:
(1) All samples at the node belong to the same class; the node becomes a leaf storing that class.
(2) All attributes have already been used by the node's parent or ancestors. In this case the node becomes a leaf labeled with the majority class among its samples; the node may also store the full class distribution.
Characteristics of ID3: (1) it is prone to overfitting; (2) it works on nominal data but has difficulty handling continuous data.
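To make the information-gain criterion concrete, here is a minimal, self-contained sketch, separate from the implementation below; the class counts come from the classic weather data set that also appears in the commented sample data further down:

// A minimal sketch of the information-gain computation.
public class InfoGainDemo {
    // Entropy of a class distribution given by counts, in bits.
    static double entropy(int[] counts) {
        int total = 0;
        for (int c : counts) total += c;
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) continue;
            double p = (double) c / total;
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    public static void main(String[] args) {
        // Parent node: 9 positive, 5 negative (the weather data set).
        int[] parent = {9, 5};
        // Split by outlook: sunny -> {2,3}, overcast -> {4,0}, rainy -> {3,2}.
        int[][] children = {{2, 3}, {4, 0}, {3, 2}};
        int total = 14;
        double remainder = 0.0;
        for (int[] child : children) {
            int n = child[0] + child[1];
            remainder += (double) n / total * entropy(child);
        }
        // Gain = H(parent) - weighted child entropy ≈ 0.940 - 0.694 = 0.246 bits.
        System.out.println("gain = " + (entropy(parent) - remainder));
    }
}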
2. C4.5
C4.5 is an improvement on ID3. Its overall procedure is similar; the improvements are:
(1) It can handle continuous data as well as nominal data. To handle a continuous attribute, the algorithm picks a threshold for that attribute at the node and splits the samples into two parts (see the sketch after this list).
(2) It can handle samples with missing attribute values. A missing value may be marked with ?, and such samples are ignored when computing the entropy gain.
(3) The tree can be pruned after construction: adjacent leaf nodes that would not yield a meaningful information gain are merged, which mitigates overfitting.
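A minimal sketch of the threshold search for one continuous attribute, assuming binary class labels; the helper names and the sample numbers in main are illustrative, not part of the implementation below:

import java.util.Arrays;

// Sketch: choose a split threshold for one continuous attribute by
// minimizing the weighted entropy of the two resulting parts.
public class ThresholdSplitDemo {
    static double entropy(int pos, int neg) {
        int n = pos + neg;
        if (n == 0 || pos == 0 || neg == 0) return 0.0;
        double p = (double) pos / n, q = (double) neg / n;
        return -(p * Math.log(p) + q * Math.log(q)) / Math.log(2);
    }

    // values[i] is sample i's attribute value; labels[i] is true for the positive class.
    static double bestThreshold(double[] values, boolean[] labels) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Double.compare(values[a], values[b]));
        int totalPos = 0;
        for (boolean l : labels) if (l) totalPos++;
        int leftPos = 0, leftN = 0;
        double bestT = values[order[0]], bestH = Double.MAX_VALUE;
        // Candidate thresholds lie between consecutive distinct sorted values.
        for (int k = 0; k < order.length - 1; k++) {
            if (labels[order[k]]) leftPos++;
            leftN++;
            double lo = values[order[k]], hi = values[order[k + 1]];
            if (lo == hi) continue;
            int rightN = values.length - leftN;
            int rightPos = totalPos - leftPos;
            double h = (double) leftN / values.length * entropy(leftPos, leftN - leftPos)
                     + (double) rightN / values.length * entropy(rightPos, rightN - rightPos);
            if (h < bestH) { bestH = h; bestT = (lo + hi) / 2; }
        }
        return bestT;
    }

    public static void main(String[] args) {
        // Hypothetical data: a perfect split exists between 3.5 and 3.9.
        double[] values = {2.1, 3.5, 3.9, 4.4, 6.0};
        boolean[] labels = {true, true, false, false, false};
        System.out.println("threshold = " + bestThreshold(values, labels)); // 3.7
    }
}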
3. CART
CART (classification and regression tree), first proposed by Breiman et al., handles both classification and regression problems. Unlike ID3, which cannot directly handle continuous features, CART uses binary splits: a threshold on one attribute divides the samples into two parts. Besides entropy gain, its splitting criteria include Gini impurity and, for regression, variance reduction.
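A minimal sketch of the two CART criteria; the class and method names are illustrative:

// Sketch of the two CART splitting criteria.
public class CartCriteriaDemo {
    // Gini impurity of a class distribution given by counts: 1 - sum(p_i^2).
    static double gini(int[] counts) {
        int total = 0;
        for (int c : counts) total += c;
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (int c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    // Variance of regression targets; a split is scored by how much it
    // reduces the variance relative to the parent node.
    static double variance(double[] targets) {
        if (targets.length == 0) return 0.0;
        double mean = 0.0;
        for (double t : targets) mean += t;
        mean /= targets.length;
        double var = 0.0;
        for (double t : targets) var += (t - mean) * (t - mean);
        return var / targets.length;
    }

    public static void main(String[] args) {
        System.out.println(gini(new int[]{9, 5}));                  // ≈ 0.459
        System.out.println(variance(new double[]{1.0, 2.0, 3.0})); // ≈ 0.667
    }
}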
Reference: http://blog.csdn.net/taigw/article/details/44840771
package com.company;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class ID3 {
private ArrayList<String> attribute = new ArrayList<>(); // attribute names
private ArrayList<ArrayList<String>> attributevalue = new ArrayList<>(); // possible values of each attribute
private ArrayList<String[]> data = new ArrayList<>(); // raw data rows; each cell is stored as "attribute:value"
int decatt; // index of the decision attribute in the attribute list
public static final String patternString = "@attribute(.*)[{](.*?)[}]";
private DecisionTreeNode rootNode;
public ID3() {
rootNode = new DecisionTreeNode();
rootNode.nodeName = "root";
rootNode.nodeValue = null;
}
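// Parse an ARFF-style file: each @attribute line declares an attribute and its value set;
// the lines after @data are the training samples. Every stored cell is prefixed as
// "attribute:value" so that branch keys are unambiguous when the tree is queried.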
public void getLearnData(File file){
try {
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
Pattern pattern = Pattern.compile(patternString);
while ((line = br.readLine()) != null) {
Matcher matcher = pattern.matcher(line);
if (matcher.find()) {
attribute.add(matcher.group(1).trim());
String[] values = matcher.group(2).split(",");
ArrayList<String> al = new ArrayList<String>(values.length);
for (String value : values) {
al.add(matcher.group(1).trim() + ":" + value.trim());
}
attributevalue.add(al);
} else if (line.startsWith("@data")) {
while ((line = br.readLine()) != null) {
if("".equals(line))
continue;
String[] row = line.split(",");
for(int i = 0 ; i < row.length ; i++){
row[i] = attribute.get(i) + ":" + row[i];
}
data.add(row);
}
}
}
br.close();
} catch (IOException e1) {
e1.printStackTrace();
}
}
public void init() {
getLearnData(new File("/Users/zxc/Documents/app.arff"));
setDec("precision");
// candidate attributes: every attribute except the decision attribute
LinkedList<Integer> ll = new LinkedList<Integer>();
for (int i = 0; i < attribute.size(); i++) {
if (i != decatt)
ll.add(i);
}
System.out.println("candidate attribute indices: " + ll);
// row indices of all training samples
ArrayList<Integer> al = new ArrayList<Integer>();
for (int i = 0; i < data.size(); i++) {
al.add(i);
}
System.out.println("training row indices: " + al);
buildDT(rootNode, al, ll);
println(rootNode, 0);
}
//alternative input path: parse attributes and data from strings (see the commented learnAttribute/learnData samples below)
private void readAttribute(String dataS) {
Pattern pattern = Pattern.compile(patternString);
Matcher matcher = pattern.matcher(dataS);
while (matcher.find()) {
attribute.add(matcher.group(1).trim());
String[] values = matcher.group(2).split(",");
ArrayList<String> al = new ArrayList<String>(values.length);
for (String value : values) {
al.add(value.trim());
}
attributevalue.add(al);
}
}
private void readData(String dataS){
String[] line = dataS.split("\n");
for (int i = 0 ; i < line.length ; i++) {
String[] row = line[i].split(",");
data.add(row);
}
}
//set the decision (class) attribute
private void setDec(int n) {
if (n < 0 || n >= attribute.size()) {
System.err.println("Invalid decision attribute index.");
System.exit(2);
}
decatt = n;
}
private void setDec(String name) {
int n = attribute.indexOf(name);
setDec(n);
}
// Entropy of a node, in bits; arr holds the number of samples in each class.
// Uses H = log2(n) - (1/n) * sum(a_i * log2(a_i)); Double.MIN_VALUE avoids log(0),
// and a zero count contributes nothing to the sum.
private double getEntropy(int[] arr) {
double entropy = 0.0;
int sum = 0;
for (int i = 0; i < arr.length; i++) {
entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE) / Math.log(2);
sum += arr[i];
}
entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
entropy /= sum;
return entropy;
}
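// Example: getEntropy(new int[]{9, 5}) ≈ 0.940 bits, the class entropy of the
// weather data set (9 "yes", 5 "no") used in the commented sample data below.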
// Weighted entropy after splitting the given subset (row indices) on the index-th attribute
private double calNodeEntropy(ArrayList<Integer> subset, int index) {
int sum = subset.size();
double entropy = 0.0;
int[][] info = new int[attributevalue.get(index).size()][];
for (int i = 0; i < info.length; i++)
info[i] = new int[attributevalue.get(decatt).size()];
int[] count = new int[attributevalue.get(index).size()];
for (int i = 0; i < sum; i++) {
int n = subset.get(i);
String nodevalue = data.get(n)[index];
int nodeind = attributevalue.get(index).indexOf(nodevalue);
count[nodeind]++;
String decvalue = data.get(n)[decatt];
int decind = attributevalue.get(decatt).indexOf(decvalue);
info[nodeind][decind]++;
}
for (int i = 0; i < info.length; i++) {
if(count[i] != 0){
entropy += getEntropy(info[i]) * count[i] / sum;
}
}
//C4.5 gain-ratio variant - start
//SplitInfo(OUTLOOK) = -5/14 * log2(5/14) - 5/14 * log2(5/14) - 4/14 * log2(4/14) = 1.577406282852345
//Info(OUTLOOK) = 5/14 * [-2/5 * log2(2/5) - 3/5 * log2(3/5)] + 4/14 * [-4/4 * log2(4/4) - 0/4 * log2(0/4)] + 5/14 * [-3/5 * log2(3/5) - 2/5 * log2(2/5)] = 0.694
//Gain(OUTLOOK) = H(play) - Info(OUTLOOK) = 0.940 - 0.694 = 0.246
//GainRatio(OUTLOOK) = Gain(OUTLOOK) / SplitInfo(OUTLOOK) = 0.246 / 1.577406282852345 = 0.15595221261270145
//Dividing the weighted entropy by the split information penalizes many-valued
//attributes, approximating the C4.5 gain ratio. Note the cast to double:
//count[i] / sum would be integer division and always yield 0.
// double splitInfo = 0;
// for (int i = 0; i < info.length; i++) {
// if (count[i] == 0) continue;
// double p = (double) count[i] / sum;
// splitInfo -= p * Math.log(p) / Math.log(2);
// }
// entropy = entropy / splitInfo;
//C4.5 - end
return entropy;
}
private boolean infoPure(ArrayList<Integer> subset) {
String value = data.get(subset.get(0))[decatt];
for (int i = 1; i < subset.size(); i++) {
String next=data.get(subset.get(i))[decatt];
//equals compares object contents; == compares whether two references point to the same object
if (!value.equals(next))
return false;
}
return true;
}
// The most common value of the decision attribute within the subset (majority class)
public String getGeneralAttr(ArrayList<Integer> subset)
{
String result = null;
int valueCount = attributevalue.get(decatt).size();
String[] values = new String[valueCount];
for (int i = 0; i < values.length; i++)
values[i] = attributevalue.get(decatt).get(i);
int[] counts = new int[valueCount]; // zero-initialized by Java
for (int i = 0; i < subset.size(); i++)
{
for (int j = 0; j < values.length; j++)
{
if (data.get(subset.get(i))[decatt].equals(values[j]))
counts[j]++;
}
}
int max = -1;
int k = 0;
for (int i = 0; i < counts.length; i++)
{
if (max < counts[i])
{
k = i;
max = counts[i];
}
}
result = attributevalue.get(decatt).get(k);
return result;
}
// Recursively build the decision tree over the given row subset and candidate attributes
private void buildDT(DecisionTreeNode node, ArrayList<Integer> subset,
LinkedList<Integer> selatt) {
//all samples share one class: make a leaf with that class
if (infoPure(subset)) {
// System.out.println(data.get(subset.get(0))[decatt]);
node.nodeKey = data.get(subset.get(0))[decatt];
return;
}
//no candidate attributes left: make a leaf labeled with the majority class
if(selatt == null || selatt.size() == 0){
// System.out.println(probability(subset));
node.nodeKey = getGeneralAttr(subset);
return;
}
//pick the attribute whose split yields the lowest weighted entropy (highest information gain)
int minIndex = -1;
double minEntropy = Double.MAX_VALUE;
for (int i = 0; i < selatt.size(); i++) {
if (selatt.get(i) == decatt) // compare the attribute index, not the loop index
continue;
//the smaller the weighted entropy, the more ordered the split
double entropy = calNodeEntropy(subset, selatt.get(i));
if (entropy < minEntropy) {
minIndex = selatt.get(i);
minEntropy = entropy;
}
}
//attribute name for the chosen index
String nodeName = attribute.get(minIndex);
//all possible values of the chosen attribute
ArrayList<String> attvalues = attributevalue.get(minIndex);
//remove the chosen attribute from the candidates; it must not be reused below this node
//(Integer.valueOf forces remove-by-object; selatt.remove(minIndex) would remove by position)
selatt.remove(Integer.valueOf(minIndex));
for (String val : attvalues) {
System.out.println("nodeName:"+nodeName +" value:" + val);
DecisionTreeNode newNode = new DecisionTreeNode();
newNode.nodeName = nodeName;
newNode.nodeValue = val;
//the attribute values are already stored as "attribute:value", so val alone is the branch key;
//prefixing nodeName again would break the lookup in result()
node.map.put(val, newNode);
//collect the rows of the subset whose chosen attribute takes this value
ArrayList<Integer> al = new ArrayList<>();
for (int i = 0; i < subset.size(); i++) {
if (data.get(subset.get(i))[minIndex].equals(val)) {
al.add(subset.get(i));
}
}
//recurse with a fresh copy of the remaining candidate attributes
LinkedList<Integer> newLinkedList = new LinkedList<>(selatt);
if (al.size() > 0) {
buildDT(newNode, al, newLinkedList);
} else {
//no training sample takes this value: label the leaf with the majority class of the parent subset
buildDT(newNode, subset, null);
}
}
}
// String learnAttribute ="" +
// "@attribute outlook {sunny, overcast, rainy}\n" +
// "@attribute temperature {hot, mild, cool}\n" +
// "@attribute humidity {high, normal}\n" +
// "@attribute windy {TRUE, FALSE}\n" +
// "@attribute play {yes, no}";
//
// String learnData =
// "sunny,hot,high,FALSE,no\n" +
// "sunny,hot,high,TRUE,no\n" +
// "sunny,mild,high,FALSE,no\n" +
// "sunny,cool,normal,FALSE,yes\n" +
// "sunny,mild,normal,TRUE,yes\n" +
// "overcast,hot,high,FALSE,yes\n" +
// "overcast,cool,normal,TRUE,yes\n" +
// "overcast,mild,high,TRUE,yes\n" +
// "overcast,hot,normal,FALSE,yes\n" +
// "rainy,mild,high,FALSE,yes\n" +
// "rainy,mild,normal,FALSE,yes\n" +
// "rainy,mild,high,TRUE,no\n" +
// "rainy,cool,normal,FALSE,yes\n" +
// "rainy,cool,normal,TRUE,no";
private void println(DecisionTreeNode node, int tree){
//print the tree as indented XML-like tags; tree is the current depth
for(int i = 0 ; i < tree ; i++){
System.out.print(" ");
}
System.out.println("<"+node.nodeName + " value=" + node.nodeValue + ">");
if(node.map.size() > 0){
Set<String> stringSet = node.map.keySet();
for(String s : stringSet){
println(node.map.get(s), tree + 1); // tree + 1, not ++tree: ++tree would deepen the indent for each sibling
}
}else {
for(int i = 0 ; i < tree ; i++){
System.out.print(" ");
}
System.out.println(node.nodeKey);
}
for(int i = 0 ; i < tree ; i++){
System.out.print(" ");
}
System.out.println("</" + node.nodeName + ">");
}
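// Classify a comma-separated sample (attribute values in file order) by walking the tree from the root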
public String result(String s){
String[] findS = s.split(",");
List<String> list = new ArrayList<>();
for(int i = 0 ; i < findS.length ; i++){
list.add(attribute.get(i) + ":" + findS[i]);
}
return result(rootNode, list);
}
//walk the tree: at each node, follow the branch whose "attribute:value" key matches the sample
private String result(DecisionTreeNode rootNode, List<String> findS){
if(rootNode.map.size() > 0){
Set<String> keySet = rootNode.map.keySet();
for(String key : keySet){
System.out.println("key:" + key);
if(findS.contains(key)){
System.out.println("找到:" + key);
findS.remove(key);
return result(rootNode.map.get(key), findS);
}
}
//no branch matches this sample (unseen attribute combination)
return null;
}else {
System.out.println("result:" + rootNode.nodeKey);
return rootNode.nodeKey;
}
}
}
//tree node
class DecisionTreeNode{
String nodeName; // attribute tested at this node
String nodeValue; // attribute value on the branch leading to this node
String nodeKey; // class label; set only on leaves
Map<String, DecisionTreeNode> map = new HashMap<>(); // children, keyed by "attribute:value"
}
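For completeness, a minimal usage sketch; the ID3 class above has no main method, and the query values below are hypothetical (they match the commented weather sample, not necessarily the hardcoded app.arff):

package com.company;

public class Main {
    public static void main(String[] args) {
        ID3 id3 = new ID3();
        // trains on the hardcoded ARFF file and prints the resulting tree
        id3.init();
        // classify one sample: comma-separated attribute values in file order (hypothetical values)
        String prediction = id3.result("sunny,hot,high,FALSE");
        System.out.println("prediction: " + prediction);
    }
}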