美文网首页
Kmeans算法

Kmeans算法

作者: 致虑 | 来源:发表于2018-08-30 20:06 被阅读0次

    Kmeans算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。Kmeans算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。

    假设要把样本集分为K个类别,算法描述如下:

    (1)适当选择K个聚类的初始中心;

    (2)在第K次迭代中,对任意一个样本,求其到K个中心的距离,将该样本归到距离最短的中心所在的簇;

    (3)重新计算各个簇(聚类)的中心;

    (4)对于所有的K个聚类中心,如果利用(2)(3)反复迭代,重新计算新旧中心的距离,若距离不变或小于某个阀值,则迭代结束。

    该算法的最大优势在于简洁和快速。算法的关键在于初始中心的选择和距离公式。

    这里最关键的地方就是初始中心K的选择,这里的选择好坏会很大程度上影响最终聚类的结果。

    算法实现:下面简单利用距离计算方式演示聚类效果:

    package com.wxshi.kmean;
    
    import java.io.BufferedReader;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.List;
    
    /**
     * Kmeans算法演示
     * @author wxshi
     *
     */
    public class Kmeans {
    
        private List<ArrayList<Double>> centers;
        private List<ArrayList<Double>> newCenters;
        private List<ArrayList<ArrayList<Double>>> clusterList;
        private int clusterNum = 5; //默认聚类的个数
    
        /**
         * 默认构造不对外抛
         */
        private Kmeans(){
    
        }
    
        public Kmeans(int clusterNum){
            if(clusterNum<=0){
                clusterNum = 5;
            }
            this.clusterNum = clusterNum;
            centers = new ArrayList<ArrayList<Double>>();
            newCenters = new ArrayList<ArrayList<Double>>();
            clusterList = new ArrayList<ArrayList<ArrayList<Double>>>();
        }
    
        /**
         * 初始化簇,开始为空
         * @param args
         * @throws IOException
         */
        public List<ArrayList<ArrayList<Double>>> initclusterList() {
            clusterList = new ArrayList<ArrayList<ArrayList<Double>>>();
            for (int i = 0; i < clusterNum; i++) {
                clusterList.add(new ArrayList<ArrayList<Double>>());
            }
            return clusterList;
        }
    
        /**
         * 初始化聚类中心节点,随机选择
         * 这里随便选择几个
         * @param dataList
         */
        private void initCenters(List<ArrayList<Double>> dataList){
            for (int i = 0; i < clusterNum; i++) {
                centers.add(dataList.get(i + 2));
                clusterList.add(new ArrayList<ArrayList<Double>>());
            }
        }
    
        /**
         * 新旧中心切换
         * 清空原来的簇中数据,重新放置数据
         */
        private void replaceCenters() {
            centers = new ArrayList<ArrayList<Double>>(newCenters);
            newCenters = new ArrayList<ArrayList<Double>>();
            initclusterList();
        }
    
        /**
         * 欧式距离计算
         * @param element1
         * @param element2
         * @return
         */
        private double distance(double element1,double element2){
            double distance = 0;
            distance = ((element1 - element2) / (element1 + element2)) * ((element1 - element2) / (element1 + element2));
            return distance;
        }
    
        /**
         * 新旧聚类中心距离计算
         * @return
         */
        private double distanceOfCenters() {
            // 计算新旧中心之间的距离,当距离小于阈值时,聚类算法结束
            double distance = 0;
            for (int i = 0; i < clusterNum; i++) {
                for (int j = 0; j < centers.get(i).size(); j++) {// 计算两点之间的距离
                    distance += distance(centers.get(i).get(j) , newCenters.get(i).get(j));
                }
            }
            return distance;
        }
    
        /**
         * 重新计算聚类中心
         */
        private void newCenters() {
            for (int i = 0; i < clusterNum; i++) {
                int len = clusterList.get(i).size();
                ArrayList<Double> tmpList = new ArrayList<Double>();
                for (int j = 0; j < centers.get(0).size(); j++) {
                    double sum = 0;
                    for (int t = 0; t < len; t++) {
                        sum += clusterList.get(i).get(t).get(j);
                    }
                    tmpList.add(sum / len);
                }
                newCenters.add(tmpList);
            }
        }
    
        /**
         * 核心方法
         * 迭代簇,将距离最近的节点加入簇
         * @param dataList
         */
        private void intoCuster(List<ArrayList<Double>> dataList){
            for (int i = 0; i < dataList.size(); i++) {
                double minDistance = 99999999;
                int centerIndex = -1;
                for (int j = 0; j < clusterNum; j++) {// 计算最近距离
                    double currentDistance = 0;
                    for (int t = 0; t < centers.get(j).size(); t++) {// 计算两点之间的距离
                        currentDistance += distance(centers.get(j).get(t) , dataList.get(i).get(t)) ;
                    }
                    if (minDistance > currentDistance) {
                        minDistance = currentDistance;
                        centerIndex = j;
                    }
                }
                clusterList.get(centerIndex).add(dataList.get(i));
            }
        }
    
        /**
         * 读取文件,获取数据
         * @param dir
         * @return
         */
        public List<ArrayList<Double>> readFile(String dir) {
            List<ArrayList<Double>> dataList = new ArrayList<ArrayList<Double>>();
            try {
                BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream("wine.txt")));
                String data = null;
                while ((data = br.readLine()) != null) {
                    String[] fields = data.split(",");
                    List<Double> tmpList = new ArrayList<Double>();
                    for (int i = 0; i < fields.length; i++) {
                        tmpList.add(Double.parseDouble(fields[i]));
                    }
                    dataList.add((ArrayList<Double>) tmpList);
                }
                br.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            return dataList;
        }
    
        /**
         * 打印结果
         */
        private void print() {
            for (int i = 0; i < clusterNum; i++) {
                System.out.println("\nCluster: " + (i + 1) + "   size: " + clusterList.get(i).size() + " :\n");
                for (int j = 0; j < clusterList.get(i).size(); j++) {
                    System.out.println(clusterList.get(i).get(j));
                }
            }
        }
    
        /**
         * @param args
         * @throws IOException
         */
        public static void main(String[] args) throws IOException {
    
            Kmeans kmeans = new Kmeans(5);
    
            // 读入原始数据
            List<ArrayList<Double>> dataList = kmeans.readFile("wine.txt");
    
            // 随机确定K个初始聚类中心
            kmeans.initCenters(dataList);
    
            // 进行若干次迭代,直到聚类中心稳定
            while (true) {
                kmeans.intoCuster(dataList);
                kmeans.newCenters();
                double distance = kmeans.distanceOfCenters();
    
                // 小于阈值时,结束循环
                if (distance == 0) {
                    break;
                }
                // 否则,新的中心来代替旧的中心,进行下一轮迭代
                else {
                    kmeans.replaceCenters();
                }
            }
    
            kmeans.print();
        }
    }
    
    image.gif

    下面代码摘自网页中,这里做个比较:

    import java.util.ArrayList;
    import java.util.Random;
    
    /**
     * K均值聚类算法
     */
    public class Kmeans2 {
    
        private int k;// 分成多少簇
        private int m;// 迭代次数
        private int dataSetLength;// 数据集元素个数,即数据集的长度
        private ArrayList<float[]> dataSet;// 数据集链表
        private ArrayList<float[]> center;// 中心链表
        private ArrayList<ArrayList<float[]>> cluster; // 簇
        private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小
        private Random random;
    
        /**
         * 设置需分组的原始数据集
         * @param dataSet
         */
        public void setDataSet(ArrayList<float[]> dataSet) {
            this.dataSet = dataSet;
        }
    
        /**
         * 获取结果分组
         * @return 结果集
         */
        public ArrayList<ArrayList<float[]>> getCluster() {
            return cluster;
        }
    
        /**
         * 构造函数,传入需要分成的簇数量
         * @param k
         *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
         */
        public Kmeans2(int k) {
            if (k <= 0) {
                k = 1;
            }
            this.k = k;
        }
    
        /**
         * 初始化
         */
        private void init() {
            m = 0;
            random = new Random();
            if (dataSet == null || dataSet.size() == 0) {
                initDataSet();
            }
            dataSetLength = dataSet.size();
            if (k > dataSetLength) {
                k = dataSetLength;
            }
            center = initCenters();
            cluster = initCluster();
            jc = new ArrayList<Float>();
        }
    
        /**
         * 如果调用者未初始化数据集,则采用内部测试数据集
         */
        private void initDataSet() {
            dataSet = new ArrayList<float[]>();
            // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
            float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
                    { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
                    { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };
    
            for (int i = 0; i < dataSetArray.length; i++) {
                dataSet.add(dataSetArray[i]);
            }
        }
    
        /**
         * 初始化中心数据链表,分成多少簇就有多少个中心点
         *
         * @return 中心点集
         */
        private ArrayList<float[]> initCenters() {
            ArrayList<float[]> center = new ArrayList<float[]>();
            int[] randoms = new int[k];
            boolean flag;
            int temp = random.nextInt(dataSetLength);
            randoms[0] = temp;
            for (int i = 1; i < k; i++) {
                flag = true;
                while (flag) {
                    temp = random.nextInt(dataSetLength);
                    int j = 0;
                    while (j < i) {
                        if (temp == randoms[j]) {
                            break;
                        }
                        j++;
                    }
                    if (j == i) {
                        flag = false;
                    }
                }
                randoms[i] = temp;
            }
            for (int i = 0; i < k; i++) {
                center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
            }
            return center;
        }
    
        /**
         * 初始化簇集合
         *
         * @return 一个分为k簇的空数据的簇集合
         */
        private ArrayList<ArrayList<float[]>> initCluster() {
            ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
            for (int i = 0; i < k; i++) {
                cluster.add(new ArrayList<float[]>());
            }
            return cluster;
        }
    
        /**
         * 计算两个点之间的距离
         *
         * @param element
         *            点1
         * @param center
         *            点2
         * @return 距离
         */
        private float distance(float[] element, float[] center) {
            float distance = 0.0f;
            float x = element[0] - center[0];
            float y = element[1] - center[1];
            float z = x * x + y * y;
            distance = (float) Math.sqrt(z);
            return distance;
        }
    
        /**
         * 获取距离集合中最小距离的位置
         *
         * @param distance
         *            距离数组
         * @return 最小距离在距离数组中的位置
         */
        private int minDistance(float[] distance) {
            float minDistance = distance[0];
            int minLocation = 0;
            for (int i = 1; i < distance.length; i++) {
                if (distance[i] < minDistance) {
                    minDistance = distance[i];
                    minLocation = i;
                } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
                {
                    if (random.nextInt(10) < 5) {
                        minLocation = i;
                    }
                }
            }
            return minLocation;
        }
    
        /**
         * 核心,将当前元素放到最小距离中心相关的簇中
         */
        private void clusterSet() {
            float[] distance = new float[k];
            for (int i = 0; i < dataSetLength; i++) {
                for (int j = 0; j < k; j++) {
                    distance[j] = distance(dataSet.get(i), center.get(j));
                }
                int minLocation = minDistance(distance);
                cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中
    
            }
        }
    
        /**
         * 求两点误差平方的方法
         *
         * @param element
         *            点1
         * @param center
         *            点2
         * @return 误差平方
         */
        private float errorSquare(float[] element, float[] center) {
            float x = element[0] - center[0];
            float y = element[1] - center[1];
            float errSquare = x * x + y * y;
            return errSquare;
        }
    
        /**
         * 计算误差平方和准则函数方法
         */
        private void countRule() {
            float jcF = 0;
            for (int i = 0; i < cluster.size(); i++) {
                for (int j = 0; j < cluster.get(i).size(); j++) {
                    jcF += errorSquare(cluster.get(i).get(j), center.get(i));
                }
            }
            jc.add(jcF);
        }
    
        /**
         * 设置新的簇中心方法
         */
        private void setNewCenter() {
            for (int i = 0; i < k; i++) {
                int n = cluster.get(i).size();
                if (n != 0) {
                    float[] newCenter = { 0, 0 };
                    for (int j = 0; j < n; j++) {
                        newCenter[0] += cluster.get(i).get(j)[0];
                        newCenter[1] += cluster.get(i).get(j)[1];
                    }
                    // 设置一个平均值
                    newCenter[0] = newCenter[0] / n;
                    newCenter[1] = newCenter[1] / n;
                    center.set(i, newCenter);
                }
            }
        }
    
        /**
         * 打印数据,测试用
         * @param dataArray
         *            数据集
         * @param dataArrayName
         *            数据集名称
         */
        public void printDataArray(ArrayList<float[]> dataArray, String dataArrayName) {
            for (int i = 0; i < dataArray.size(); i++) {
                System.out.println("print:" + dataArrayName + "[" + i + "]={" + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
            }
            System.out.println("===================================");
        }
    
        /**
         * Kmeans算法核心过程方法
         */
        private void kmeans() {
            init();
            // 循环分组,直到误差不变为止
            while (true) {
                clusterSet();
                countRule();
                // 误差不变了,分组完成
                if (m != 0) {
                    if (jc.get(m) - jc.get(m - 1) == 0) {
                        break;
                    }
                }
                setNewCenter();
                m++;
                cluster.clear();
                cluster = initCluster();
            }
        }
    
        /**
         * 执行算法
         */
        public void execute() {
            long startTime = System.currentTimeMillis();
            System.out.println("kmeans begins");
            kmeans();
            long endTime = System.currentTimeMillis();
            System.out.println("kmeans running time=" + (endTime - startTime) + "ms");
            System.out.println("kmeans ends");
            System.out.println();
        }
    
        public  static void main(String[] args){
            //初始化一个Kmean对象,将k置为10
            Kmeans2 k=new Kmeans2(4);
            ArrayList<float[]> dataSet=new ArrayList<float[]>();
    
            dataSet.add(new float[]{1,22});
            dataSet.add(new float[]{3,333});
            dataSet.add(new float[]{3,4});
            dataSet.add(new float[]{5,6});
            dataSet.add(new float[]{8,9999});
            dataSet.add(new float[]{4,5});
            dataSet.add(new float[]{6,4});
            dataSet.add(new float[]{3,95});
            dataSet.add(new float[]{5,9});
            dataSet.add(new float[]{4,7777});
            dataSet.add(new float[]{1,9});
            dataSet.add(new float[]{7,844});
            //设置原始数据集
            k.setDataSet(dataSet);
            //执行算法
            k.execute();
            //得到聚类结果
            ArrayList<ArrayList<float[]>> cluster=k.getCluster();
            //查看结果
            for(int i=0;i<cluster.size();i++){
                k.printDataArray(cluster.get(i), "cluster["+i+"]");
            }
    
        }
    
    }
    
    
    image.gif

    大致思想不变:就是不断选择聚类中心,根据距离选择加入簇的节点,不断迭代,直到距离小于某个阀值或不变则聚类结束。

    相关文章

      网友评论

          本文标题:Kmeans算法

          本文链接:https://www.haomeiwen.com/subject/lggdwftx.html