LightLDA Output Interface (Java Version)

Author: yxwithu | Published 2017-08-15 14:47

    From LightLDA's output files, build the document-topic distribution and the topic-word distribution, then extract the top-N keywords that best represent each document.
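
    For reference, the probabilities computed below are the standard smoothed estimates; this matches the arithmetic in loadDocTopicModel, loadTopicWordModel, and getDocTopWordInfo, though the symbols theta, phi, and n are my notation, not LightLDA's:

        \theta_{dk} = \frac{n_{dk} + \alpha}{n_d + K\alpha}
        \phi_{kw}   = \frac{n_{kw} + \beta}{n_k + V\beta}
        p(w \mid d) = \sum_k \theta_{dk} \, \phi_{kw}

    where n_{dk} counts assignments of topic k in document d, n_d is the token count of document d, n_{kw} counts assignments of word w to topic k, n_k is the total count for topic k, K = topic_num, and V = vocab_num. getDocTopWordInfo ranks words by p(w|d).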

    import java.io.BufferedReader;
    import java.io.BufferedWriter;
    import java.io.FileReader;
    import java.io.FileWriter;
    import java.util.List;
    import java.util.PriorityQueue;
    
    /**
     * Created by yangxin on 2017/8/11.
     */
    public class LDAResult {
        private double alpha;  // Dirichlet prior of the document-topic distribution
        private double beta;   // Dirichlet prior of the topic-word distribution
        private int topic_num;  // number of topics
        private int vocab_num;  // vocabulary size
        private int doc_num;    // number of documents
        private double[][] doc_topic_mat = null;   // document-topic probability matrix, indexed [topic][doc]
        private double[][] topic_vocab_mat = null; // topic-word probability matrix, indexed [word][topic]
        private Item[][] doc_word_info = null;     // per-document top-keyword matrix
    
        /**
         * One of the top-n words of a document: word id plus its probability.
         */
        public static class Item implements Comparable<Item>{
            public double prob;
            public int word_id;
    
            public Item(double prob, int word_id) {
                this.prob = prob;
                this.word_id = word_id;
            }
    
            @Override
            public String toString() {
                return "Item{" +
                        "prob=" + prob +
                        ", word_id=" + word_id +
                        '}';
            }
    
            @Override
            public int compareTo(Item o) {
                // ascending by probability, so a PriorityQueue<Item> behaves as a min-heap
                return Double.compare(prob, o.prob);
            }
        }
    
        public LDAResult(double alpha, double beta, int topic_num, int vocab_num, int doc_num) {
            this.alpha = alpha;
            this.beta = beta;
            this.topic_num = topic_num;
            this.vocab_num = vocab_num;
            this.doc_num = doc_num;
    
            doc_topic_mat = new double[topic_num][doc_num];
            topic_vocab_mat = new double[vocab_num][topic_num];
        }
    
        /**
         * Compute the top n keywords for each document.
         * @param n  number of keywords to keep per document
         * @return   one row per document, sorted by descending probability
         */
        public Item[][] getDocTopWordInfo(int n){
            doc_word_info = new Item[doc_num][n];
            for(int i = 0; i < doc_num; ++i){  // for each document
                // min-heap of at most n entries: polling evicts the least probable word
                PriorityQueue<Item> queue = new PriorityQueue<>();
                for(int j = 0; j < vocab_num; ++j){  // for each word
                    double prob = 0;
                    for(int k = 0; k < topic_num; ++k){  // for each topic
                        prob += doc_topic_mat[k][i] * topic_vocab_mat[j][k];  // p(w|d) = sum_k p(k|d) * p(w|k)
                    }
                    Item item = new Item(prob, j);
                    queue.offer(item);
                    if(queue.size() > n){
                        queue.poll();
                    }
                }
                // drain the heap (least probable first), filling the row back to front
                int q = queue.size();
                while(!queue.isEmpty()){
                    doc_word_info[i][--q] = queue.poll();
                }
            }
            return doc_word_info;
        }
    
        /**
         * Write each document's top n keywords to a file.
         * @param n       number of keywords per document
         * @param output  output file path
         * @param titles  list of document titles
         * @param words   vocabulary list
         * @throws Exception
         */
        public void dumpTopResult(int n, String output, final List<String> titles, final List<String> words) throws Exception{
            if(n <= 0) return;
            BufferedWriter bw = new BufferedWriter(new FileWriter(output));
            if(doc_word_info == null){
                doc_word_info = getDocTopWordInfo(n);
            }
    
            for(int i = 0; i < doc_num; ++i){  // for each document
                bw.write(titles.get(i) + " : ");
                for(Item item : doc_word_info[i]){
                    bw.write(words.get(item.word_id) + "/" + item.prob + " ");
                }
                bw.newLine();
                bw.flush();
            }
    
            bw.close();
        }
    
        /**
         * Load the document-topic model.
         * @param model_path  path to LightLDA's document-topic output file
         * @throws Exception
         */
        public void loadDocTopicModel(String model_path) throws Exception{
            // read the raw topic counts into the matrix
            BufferedReader br = new BufferedReader(new FileReader(model_path));
            String line = null;
            while((line = br.readLine()) != null){
                String[] doc_info = line.split("[\t ]+");
                int doc_id = Integer.parseInt(doc_info[0]);  // document id, starting from 0

                for(int i = 1; i < doc_info.length; ++i){
                    String[] topic_info = doc_info[i].split(":");    // a topic_id:count pair
                    int topic_id = Integer.parseInt(topic_info[0]);  // topic id
                    int topic_cnt = Integer.parseInt(topic_info[1]); // times the topic was assigned in this document
                    doc_topic_mat[topic_id][doc_id] = topic_cnt;
                }
            }
            br.close();
    
            // per-document totals: the sum over topics equals the document's token count
            int[] doc_cnts = new int[doc_num];
            for(int i = 0; i < doc_num; ++i){        // for each document
                for(int j = 0; j < topic_num; ++j){  // for each topic
                    doc_cnts[i] += doc_topic_mat[j][i];
                }
            }

            // convert counts to smoothed probabilities: (count + alpha) / (total + K * alpha)
            double factor = topic_num * alpha;
            for(int i = 0; i < doc_num; ++i){        // for each document
                for(int j = 0; j < topic_num; ++j){  // for each topic
                    doc_topic_mat[j][i] = (doc_topic_mat[j][i] + alpha) / (doc_cnts[i] + factor);
                }
            }
        }
    
        /**
         * Load the topic-word model.
         * @param model_path          path to the topic-word model (the server_model_0 file)
         * @param model_summary_path  path to the per-topic count summary (the server_model_1 file)
         * @throws Exception
         */
        public void loadTopicWordModel(String model_path, String model_summary_path) throws Exception{
            // read the raw topic counts into the matrix
            BufferedReader br = new BufferedReader(new FileReader(model_path));
            String line = null;
            while((line = br.readLine()) != null){
                String[] info = line.split(" ");
                int word_id = Integer.parseInt(info[0]);  // word id
                for(int i = 1; i < info.length; ++i){
                    String[] topic_info = info[i].split(":");        // a topic_id:count pair
                    int topic_id = Integer.parseInt(topic_info[0]);  // topic id
                    int topic_cnt = Integer.parseInt(topic_info[1]); // times the word was assigned to this topic
                    topic_vocab_mat[word_id][topic_id] = topic_cnt;
                }
            }
            br.close();
    
            // read the total number of word assignments for each topic
            int[] topic_cnts = new int[topic_num];
            br = new BufferedReader(new FileReader(model_summary_path));
            String[] cnts = br.readLine().split(" ");
            for(int i = 1; i < cnts.length; ++i){  // the first token is skipped; the rest are topic_id:count pairs
                String[] cnt_info = cnts[i].split(":");
                int topic_id = Integer.parseInt(cnt_info[0]);
                int topic_cnt = Integer.parseInt(cnt_info[1]);
                topic_cnts[topic_id] = topic_cnt;
            }
            br.close();
    
            // convert counts to smoothed probabilities: (count + beta) / (total + V * beta)
            double factor = vocab_num * beta;        // normalization factor
            for(int i = 0; i < vocab_num; ++i){      // for each word
                for(int j = 0; j < topic_num; ++j){  // for each topic
                    topic_vocab_mat[i][j] = (topic_vocab_mat[i][j] + beta) / (topic_cnts[j] + factor);
                }
            }
        }
    }
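
    Judging from the parsers above, the three input files look roughly like this; the concrete ids and counts here are made up for illustration:

        # document-topic file (e.g. doc_topic.0): doc_id topic_id:count ...
        0 3:12 17:4 96:1
        1 0:7 3:2

        # topic-word file (server_model_0): word_id topic_id:count ...
        42 3:118 50:9

        # topic summary (server_model_1): a single line, leading id then topic_id:count ...
        0 0:10233 1:8710 3:15027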
    

    Usage

    public static void main(String[] args) throws Exception{
        String doc_topic_path = "doc_topic.0";
        String topic_word_path = "server_0_table_0.model";
        String topic_summary = "server_0_table_1.model";
        String ori_doc_path = "merge_texts";
        String ori_word_path = "vocab";
        String output = "result";
        LDAResult result = new LDAResult(0.22, 0.1, 220, 1539967, 146119);
        result.loadTopicWordModel(topic_word_path, topic_summary);  // build the topic-word distribution
        result.loadDocTopicModel(doc_topic_path);   // build the document-topic distribution
        List<String> titles = Util.getTitles(ori_doc_path);  // all document titles
        List<String> words = Util.getVocabs(ori_word_path);  // the full vocabulary
        result.dumpTopResult(10, output, titles, words);  // write each document's top 10 keywords to output
    }
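
    With n = 10, each line of the result file then has the form (values illustrative):

        title : word1/0.01234 word2/0.00981 ... word10/0.00112

    Util.getTitles and Util.getVocabs are not part of the listing above. A minimal sketch of what they might look like, assuming the titles file and the vocab file each hold one entry per line in id order:

        import java.io.BufferedReader;
        import java.io.FileReader;
        import java.util.ArrayList;
        import java.util.List;

        /** Hypothetical stand-in for the Util class, which the original post does not show. */
        public class Util {
            // Read a file into a list, one entry per line; the line index doubles as the id.
            private static List<String> readLines(String path) throws Exception {
                List<String> lines = new ArrayList<>();
                BufferedReader br = new BufferedReader(new FileReader(path));
                String line;
                while ((line = br.readLine()) != null) {
                    lines.add(line.trim());
                }
                br.close();
                return lines;
            }

            public static List<String> getTitles(String doc_path) throws Exception {
                return readLines(doc_path);   // assumes one document title per line
            }

            public static List<String> getVocabs(String vocab_path) throws Exception {
                return readLines(vocab_path); // assumes one word per line, ordered by word id
            }
        }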
    
