ids3

作者: everyhook | 来源:发表于2017-12-08 17:29 被阅读0次

package ml.package4;

import org.junit.Test;

import java.util.*;

import static java.util.stream.Collectors.groupingBy;

/**

  • Created by on 2017/12/8.
    */
    public class test {
    @Test
    public void testId3() {
    String[][] data = {
    {"青年", "否", "否", "一般", "否"},
    {"青年", "否", "否", "好", "否"},
    {"青年", "是", "否", "好", "是"},
    {"青年", "是", "是", "一般", "是"},
    {"青年", "否", "否", "一般", "否"},
    {"中年", "否", "否", "一般", "否"},
    {"中年", "否", "否", "好", "否"},
    {"中年", "是", "是", "好", "是"},
    {"中年", "否", "是", "非常好", "是"},
    {"中年", "否", "是", "非常好", "是"},
    {"老年", "否", "是", "非常好", "是"},
    {"老年", "否", "是", "好", "是"},
    {"老年", "是", "否", "好", "是"},
    {"老年", "是", "否", "非常好", "是"},
    {"老年", "否", "否", "一般", "否"}
    };
    List<String> title = Arrays.asList("年龄", "有工作", "有自己的房子", "信贷情况", "类别");
    ID3Node node = buildId3Tree(data, title);
    System.out.println(node);
    }
    private String[][] calcData(String[][] data, int column, String val) {
    Object[] tp = Arrays.stream(data).filter(o -> o[column].equals(val)).toArray();
    String[][] rs = new String[tp.length][data[0].length - 1];
    for (int i = 0; i < tp.length; i++) {
    String[] row = (String[]) tp[i];
    for (int j = 0, rsIndex = 0; j < row.length; j++) {
    if (j != column) {
    rs[i][rsIndex] = row[j];
    rsIndex++;
    }
    }
    }
    return rs;
    }
    private ID3Node buildId3Tree(String[][] data, List<String> title) {
    ID3Node id3Node = new ID3Node();
    final double[] val = {0, 0, 0, 0, 0};
    Arrays.stream(data).collect(groupingBy(o -> o[data[0].length - 1]))
    .forEach((name, list) -> val[data[0].length - 1] -= 1d * list.size() / data.length * ln(1d * list.size() / data.length));
    if (val[data[0].length - 1] == 0 || title.size() == 1) {
    id3Node.title = title.get(title.size() - 1);
    id3Node.val = data[0][data[0].length - 1];
    return id3Node;
    }
    List<List<String>> ids = new ArrayList<>();
    for (final int[] i = {0}; i[0] < data[0].length - 1; i[0]++) {
    List<String> id = new ArrayList<>();
    Arrays.stream(data).collect(groupingBy(o -> o[i[0]]))
    .forEach((name, list) -> {
    id.add(name);
    final double[] tp = {0};
    list.stream().collect(groupingBy(oo -> oo[data[0].length - 1]))
    .forEach(
    (_name, _list) -> tp[0] -= 1d * _list.size() / list.size() * ln(1d * _list.size() / list.size())
    );
    val[i[0]] += tp[0] * list.size() / data.length;
    });
    ids.add(id);
    }
    double v = -Double.MAX_VALUE;
    int index = -1;
    for (int i = 0; i < data[0].length - 1; i++) {
    System.out.println(title.get(i) + "->" + (val[data[0].length - 1] - val[i]));
    if (v < val[data[0].length - 1] - val[i]) {
    v = val[data[0].length - 1] - val[i];
    index = i;
    }
    }
    System.out.println(title.get(index) + " ->" + ids.get(index));
    id3Node.title = title.get(index);
    if (id3Node.children == null)
    id3Node.children = new HashMap<>();
    List<String> _title = new ArrayList<>();
    for (String str : title)
    if (!str.equals(title.get(index)))
    _title.add(str);
    for (String str : ids.get(index))
    id3Node.children.put(str, buildId3Tree(calcData(data, index, str), _title));
    return id3Node;
    }
    private double ln(double v) {
    if (v == 0d)
    return 0d;
    return Math.log(v) / Math.log(2d);
    }
    private class ID3Node {
    String val;
    String title;
    Map<String, ID3Node> children;
    }
    }

相关文章

  • ids3

    package ml.package4; import org.junit.Test; import java.u...

网友评论

      本文标题:ids3

      本文链接:https://www.haomeiwen.com/subject/asxyixtx.html