美文网首页机器学习与数据挖掘
在Weka中实现流形学习Isomap中的距离计算

在Weka中实现流形学习Isomap中的距离计算

作者: Daisy丶 | 来源:发表于2017-03-19 15:25 被阅读161次

    最近因为项目需求,需要在Weka上实现流形距离计算。由于Weka没有提供流形学习的包,而Smile提供了,于是我参照Smile中等距映射(Isomap)的实现,重写了一个可在Weka上使用的流形距离计算类。

    欧氏距离是最常用的距离度量,但是当数据集不具有全局线性结构时,欧氏距离就不是一种合理的数据距离度量,此时一般使用拓扑流形结构来度量高维的非线性数据。这种方法通常用于对数据进行降维,也被称为流形学习。

    定义1:
    流形上两点 $x_1, x_2$ 间的线段长度定义为 $L(x_1, x_2) = e^{d(x_1, x_2)/\sigma} - 1$,其中 $d(x_1, x_2)$ 为两点间的欧氏距离,$\sigma > 0$ 为伸缩因子。
    定义2:
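    将数据点看作无向有权图 $G=(V, E)$ 的顶点,其中 $V$ 是顶点集合,$E$ 是边集合。令 $P_{ij}$ 表示图上连接数据点 $x_i$ 与 $x_j$ 的所有路径的集合,则 $x_i, x_j$ 的流形距离为 $MD(x_i, x_j) = \min_{p \in P_{ij}} \sum_{k=1}^{|p|-1} L(p_k, p_{k+1})$,其中 $p_k$ 表示路径 $p$ 上的第 $k$ 个顶点,$|p|$ 为路径上的顶点数。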

    算法流程:

    for i = 1,2,3...m do
        确定xi的k个最近邻
        将xi与其k个最近邻之间的距离设为定义1中的线段长度,与自身的距离设为0,与其他点的距离设为-1(表示两点不直接相连)
        将这些数值添加进入邻接矩阵
    end
    
    根据邻接矩阵构建一个有权无向图的对象
    使用Dijkstra算法求出图上任意两点间的最短路径长度
    
    

    ManifoldDistance.java

    import weka.core.EuclideanDistance;
    import weka.core.Instances;
    
    import java.util.*;
    
    /**
     * Computes manifold (geodesic) distances between the instances of a Weka data set,
     * using the neighbourhood-graph construction of Isomap.
     *
     * <p>Each instance is connected to its {@code k} nearest neighbours, measured with Weka's
     * {@link EuclideanDistance}. An edge whose endpoints are at Euclidean distance {@code d}
     * gets the weight {@code exp(d / sigma) - 1}. The manifold distance between two instances
     * is the length of the shortest path between them in that graph.
     *
     * <p>Call {@link #build()} before {@link #getDistance(String, String)}. Vertex ids are the
     * instance indices written as decimal strings ({@code "0"} .. {@code "n-1"}).
     */
    public class ManifoldDistance {
        private final Instances data;
        private final int k;
        private final double sigma;
        /** Symmetric adjacency matrix: edge weight, 0 on the diagonal, -1 where there is no edge. */
        private double[][] matrix;
        private final Graph graph = new Graph();

        /**
         * Creates a manifold distance calculator; nothing is computed until {@link #build()}.
         *
         * @param data  the data set whose instances become graph vertices
         * @param k     number of nearest neighbours each instance is connected to (at least 1)
         * @param sigma scale parameter σ of the edge weight {@code exp(d / σ) - 1}; must be > 0
         * @throws NullPointerException     if {@code data} is null
         * @throws IllegalArgumentException if {@code k < 1} or {@code sigma} is not positive
         */
        public ManifoldDistance(Instances data, int k, double sigma) {
            this.data = Objects.requireNonNull(data, "data");
            if (k < 1) {
                throw new IllegalArgumentException("k must be at least 1, got " + k);
            }
            if (!(sigma > 0)) { // written this way so that NaN is rejected too
                throw new IllegalArgumentException("sigma must be positive, got " + sigma);
            }
            this.k = k;
            this.sigma = sigma;
        }

        public Instances getData() {
            return data;
        }

        public int getK() {
            return k;
        }

        public double getSigma() {
            return sigma;
        }

        /**
         * Returns the adjacency matrix produced by {@link #build()}, or {@code null} before that.
         * The returned array is the internal one; modifying it changes later distance results.
         *
         * @return the symmetric weight matrix (-1 = no edge, 0 on the diagonal)
         */
        public double[][] getMatrix() {
            return matrix;
        }

        /**
         * Builds the k-nearest-neighbour weight matrix of {@link #data}.
         *
         * @return symmetric matrix where entry [i][j] is the edge weight between instances i and j,
         *         0 on the diagonal, and -1 where the two instances are not connected
         */
        private double[][] constructWeightMatrix() {
            int num = this.data.numInstances();
            double[][] weightMatrix = new double[num][num];
            for (double[] row : weightMatrix) {
                Arrays.fill(row, -1.0);
            }
            EuclideanDistance calculateDistance = new EuclideanDistance(this.data);

            for (int i = 0; i < num; i++) {
                weightMatrix[i][i] = 0.0;
                // Edge lengths from instance i to every OTHER instance (self is excluded so it
                // cannot occupy one of the k neighbour slots).
                Map<Integer, Double> lengths = new HashMap<>();
                for (int j = 0; j < num; j++) {
                    if (i != j) {
                        double dist = calculateDistance.distance(this.data.instance(i), this.data.instance(j));
                        lengths.put(j, Math.exp(dist / this.sigma) - 1);
                    }
                }
                // Symmetrise: an edge exists if either endpoint has the other among its k nearest
                // neighbours. Entries are never reset to -1, so a zero weight (duplicate instances)
                // is kept as a real edge.
                for (int n : nearestNeighbor(lengths)) {
                    double w = lengths.get(n);
                    weightMatrix[i][n] = w;
                    weightMatrix[n][i] = w;
                }
            }
            return weightMatrix;
        }

        /**
         * Selects the k nearest neighbours.
         *
         * @param lengths edge length from the current instance to each other instance, by index
         * @return indices of the (at most) k entries with the smallest length, nearest first
         */
        private List<Integer> nearestNeighbor(Map<Integer, Double> lengths) {
            List<Map.Entry<Integer, Double>> entries = new ArrayList<>(lengths.entrySet());
            // Ascending order: the length is monotone in Euclidean distance, so the smallest
            // lengths are the nearest instances. Ties are broken by index for determinism.
            entries.sort((a, b) -> {
                int byLength = Double.compare(a.getValue(), b.getValue());
                return byLength != 0 ? byLength : Integer.compare(a.getKey(), b.getKey());
            });
            int count = Math.min(this.k, entries.size());
            List<Integer> index = new ArrayList<>(count);
            for (int c = 0; c < count; c++) {
                index.add(entries.get(c).getKey());
            }
            return index;
        }

        /**
         * Builds the adjacency matrix and the corresponding weighted graph.
         * Every edge is registered once in each direction.
         */
        public void build() {
            this.matrix = constructWeightMatrix();

            int num = this.matrix.length;

            HashMap<String, List<Vertex>> edge = new HashMap<>();
            for (int i = 0; i < num; i++) {
                edge.put(Integer.toString(i), new ArrayList<>());
            }

            // The matrix is symmetric, so only the upper triangle is scanned to avoid adding
            // each edge twice. Weight -1 means "no edge"; 0 is a valid weight.
            for (int i = 0; i < num; i++) {
                for (int j = i + 1; j < num; j++) {
                    double w = this.matrix[i][j];
                    if (w >= 0) {
                        edge.get(Integer.toString(i)).add(new Vertex(Integer.toString(j), w));
                        edge.get(Integer.toString(j)).add(new Vertex(Integer.toString(i), w));
                    }
                }
            }

            for (Map.Entry<String, List<Vertex>> entry : edge.entrySet()) {
                this.graph.addVertex(entry.getKey(), entry.getValue());
            }
        }

        /**
         * Returns the shortest-path (manifold) distance between two instances and prints the path.
         *
         * @param start id of the first instance (its index as a decimal string)
         * @param end   id of the second instance
         * @return the shortest-path length, 0.0 when {@code start} and {@code end} are the same
         *         instance, or {@link Double#POSITIVE_INFINITY} when no path connects them
         * @throws IllegalStateException    if {@link #build()} has not been called
         * @throws IllegalArgumentException if an id is not an integer in [0, numInstances)
         */
        public double getDistance(String start, String end) {
            if (this.matrix == null) {
                throw new IllegalStateException("build() must be called before getDistance()");
            }
            // Normalise ids to the exact keys registered by build().
            String from = Integer.toString(toIndex(start));
            String to = Integer.toString(toIndex(end));

            // The graph returns the path from `to` back towards `from`, excluding `from` itself.
            List<String> path = new ArrayList<>(this.graph.getShortestPath(from, to));
            path.add(from);
            Collections.reverse(path);

            // An unreachable target yields a path that does not end at `to`; also reject any
            // hop that is not an actual edge instead of summing the -1 sentinel.
            boolean reachable = path.get(path.size() - 1).equals(to);
            double mDist = 0.0;
            for (int i = 0; reachable && i < path.size() - 1; i++) {
                double w = this.matrix[Integer.parseInt(path.get(i))][Integer.parseInt(path.get(i + 1))];
                if (w < 0) {
                    reachable = false;
                } else {
                    mDist += w;
                }
            }
            if (!reachable) {
                path.clear();
                mDist = Double.POSITIVE_INFINITY;
            }

            System.out.println("shortest path:" + path);
            return mDist;
        }

        /** Parses a vertex id and checks that it names an existing instance. */
        private int toIndex(String id) {
            int index = Integer.parseInt(id);
            if (index < 0 || index >= this.matrix.length) {
                throw new IllegalArgumentException(
                        "vertex id " + id + " is outside [0, " + this.matrix.length + ")");
            }
            return index;
        }
    }
    
    

    Graph.java

    import java.util.*;
    
    /**
     * Weighted graph stored as adjacency lists, with Dijkstra shortest-path search.
     *
     * <p>Each vertex id maps to the list of its neighbours. A neighbour is a {@link Vertex} whose
     * distance is the weight of the connecting edge. Edge weights must be non-negative, as
     * Dijkstra's algorithm requires.
     */
    class Graph {

        private final Map<String, List<Vertex>> vertices;

        public Graph() {
            this.vertices = new HashMap<>();
        }

        /**
         * Registers a vertex together with its adjacency list, replacing any list stored before.
         *
         * @param character the vertex id
         * @param vertex    the neighbours of the vertex, each carrying the edge weight
         */
        public void addVertex(String character, List<Vertex> vertex) {
            this.vertices.put(character, vertex);
        }

        /**
         * Finds a shortest path between two vertices using Dijkstra's algorithm.
         *
         * @param start  id of the source vertex
         * @param finish id of the target vertex
         * @return the vertex ids on the shortest path listed from {@code finish} back towards
         *         {@code start}: the list starts with {@code finish} and does not contain
         *         {@code start}. The list is empty when {@code start} equals {@code finish}, and
         *         also when {@code finish} cannot be reached. The list is always a new, mutable list.
         */
        public List<String> getShortestPath(String start, String finish) {
            final Map<String, Double> distances = new HashMap<>();
            final Map<String, String> previous = new HashMap<>();
            // Lazy deletion: a vertex may be queued several times; entries whose distance is
            // larger than the best known one are stale and skipped. This avoids the O(V)
            // remove-and-reinsert that a PriorityQueue would otherwise need for decrease-key.
            final PriorityQueue<Vertex> queue = new PriorityQueue<>();

            distances.put(start, 0.0);
            queue.add(new Vertex(start, 0.0));

            while (!queue.isEmpty()) {
                Vertex current = queue.poll();
                String id = current.getId();
                double dist = current.getDistance();
                if (dist > distances.get(id)) {
                    continue;
                }

                if (Objects.equals(id, finish)) {
                    final List<String> path = new ArrayList<>();
                    for (String step = id;
                         !Objects.equals(step, start) && previous.containsKey(step);
                         step = previous.get(step)) {
                        path.add(step);
                    }
                    return path;
                }

                // Neighbour ids that were never registered simply have no outgoing edges.
                for (Vertex neighbor : vertices.getOrDefault(id, Collections.emptyList())) {
                    double alt = dist + neighbor.getDistance();
                    Double known = distances.get(neighbor.getId());
                    if (known == null || alt < known) {
                        distances.put(neighbor.getId(), alt);
                        previous.put(neighbor.getId(), id);
                        queue.add(new Vertex(neighbor.getId(), alt));
                    }
                }
            }
            // Queue exhausted without reaching `finish`: no path exists.
            return new ArrayList<>();
        }
    }
    

    Vertex.java

    /**
     * A vertex id paired with a distance, ordered by distance and then by id.
     *
     * <p>The distance is an edge weight when the vertex sits in an adjacency list, and a
     * tentative path length when it sits in a priority queue. {@code distance} is part of
     * {@link #equals} and {@link #hashCode}, so do not mutate a vertex while it is stored in a
     * hash-based or sorted collection.
     */
    class Vertex implements Comparable<Vertex> {

        private String id;
        private Double distance;

        public Vertex(String id, Double distance) {
            this.id = id;
            this.distance = distance;
        }

        public String getId() {
            return id;
        }

        public Double getDistance() {
            return distance;
        }

        public void setId(String id) {
            this.id = id;
        }

        public void setDistance(Double distance) {
            this.distance = distance;
        }

        @Override
        public int hashCode() {
            // Same value as the former hand-rolled 31-multiplier hash over (distance, id).
            return java.util.Objects.hash(distance, id);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Vertex other = (Vertex) obj;
            return java.util.Objects.equals(distance, other.distance)
                    && java.util.Objects.equals(id, other.id);
        }

        @Override
        public String toString() {
            return "Vertex [id=" + id + ", distance=" + distance + "]";
        }

        /**
         * Orders by distance, then by id.
         *
         * <p>{@link Double#compare} gives a total order (NaN is largest and -0.0 sorts before 0.0).
         * That order is consistent with {@link Double#equals}, so {@code compareTo} returns 0
         * exactly when {@link #equals} is true.
         */
        @Override
        public int compareTo(Vertex o) {
            int byDistance = Double.compare(this.distance, o.distance);
            return byDistance != 0 ? byDistance : this.getId().compareTo(o.getId());
        }

    }
    
    

    Demo.java

    import weka.core.Instances;
    
    import java.io.FileReader;
    import java.io.IOException;
    
    /**
     * Example driver: prints the manifold weight matrix of an ARFF data set and the manifold
     * distance between instances 10 and 71 in both directions.
     */
    public class Demo {
        /**
         * @param args optional; {@code args[0]} overrides the ARFF file path
         *             (default {@code Test/Manifold/cpu.arff})
         * @throws IOException if the ARFF file cannot be read
         */
        public static void main(String[] args) throws IOException {
            String path = args.length > 0 ? args[0] : "Test/Manifold/cpu.arff";
            Instances data;
            // try-with-resources closes the reader as soon as Weka has parsed the file.
            try (FileReader reader = new FileReader(path)) {
                data = new Instances(reader);
            }

            ManifoldDistance manifold = new ManifoldDistance(data, 20, 2);
            manifold.build();
            for (double[] aMtx : manifold.getMatrix()) {
                for (double v : aMtx) {
                    System.out.print(v + "   ");
                }
                System.out.println();
            }

            System.out.println(manifold.getDistance("10", "71"));
            System.out.println(manifold.getDistance("71", "10"));
        }
    }
    

    相关文章

      网友评论

        本文标题:在Weka中实现流形学习Isomap中的距离计算

        本文链接:https://www.haomeiwen.com/subject/axmunttx.html