根据LightLDA的输出文件得到文档-主题分布和主题-词分布以及表示某篇文档的topN关键词。
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.List;
import java.util.PriorityQueue;
/**
* Created by yangxin on 2017/8/11.
*/
public class LDAResult {
private double alpha; //主题分布Dirichlet分布参数
private double beta; //词分布Dirichlet分布参数
private int topic_num; //主题数目
private int vocab_num; //词数目
private int doc_num; //文档数目
private double[][] doc_topic_mat = null; //文档_主题概率矩阵
private double[][] topic_vocab_mat = null; //主题_词概率矩阵
private Item[][] doc_word_info = null; //文档_top词的信息矩阵
/**
* lda每个doc对应的前n个词Id
*/
public static class Item implements Comparable<Item>{
public double prob;
public int word_id;
public Item(double prob, int word_id) {
this.prob = prob;
this.word_id = word_id;
}
@Override
public String toString() {
return "Item{" +
"prob=" + prob +
", word_id=" + word_id +
'}';
}
@Override
public int compareTo(Item o) {
return prob - o.prob > 0 ? 1 : -1;
}
}
public LDAResult(double alpha, double beta, int topic_num, int vocab_num, int doc_num) {
this.alpha = alpha;
this.beta = beta;
this.topic_num = topic_num;
this.vocab_num = vocab_num;
this.doc_num = doc_num;
doc_topic_mat = new double[topic_num][doc_num];
topic_vocab_mat = new double[vocab_num][topic_num];
}
/**
* 得到每个文档前n个关键词
* @param n
* @return
*/
public Item[][] getDocTopWordInfo(int n){
doc_word_info = new Item[doc_num][n];
for(int i = 0; i < doc_num; ++i){ //每篇文档
PriorityQueue<Item> queue = new PriorityQueue<>();
for(int j = 0; j < vocab_num; ++j){ //每个词
double prob = 0;
for(int k = 0; k < topic_num; ++k){ //每个主题
prob += doc_topic_mat[k][i] * topic_vocab_mat[j][k];
}
Item item = new Item(prob, j);
queue.offer(item);
if(queue.size() > n){
queue.poll();
}
}
int q = queue.size();
while(!queue.isEmpty()){
doc_word_info[i][--q] = queue.poll();
}
}
return doc_word_info;
}
/**
* 写每个文档的前n个关键词到文件中
* @param n
* @param output 输出文件
* @param titles doc标题列表
* @param words 词列表
* @throws Exception
*/
public void dumpTopResult(int n, String output, final List<String> titles, final List<String> words) throws Exception{
if(n <= 0) return;
BufferedWriter bw = new BufferedWriter(new FileWriter(output));
if(doc_word_info == null){
doc_word_info = getDocTopWordInfo(n);
}
for(int i = 0; i < doc_num; ++i){ //doc_id
bw.write(titles.get(i) + " : ");
for(Item item : doc_word_info[i]){
bw.write(words.get(item.word_id) + "/" + item.prob + " ");
}
bw.newLine();
bw.flush();
}
bw.close();
}
/**
* 加载文档_主题模型
* @param model_path
* @throws Exception
*/
public void loadDocTopicModel(String model_path) throws Exception{
//将计数写入到矩阵中
BufferedReader br = new BufferedReader(new FileReader(model_path));
String line = null;
while((line = br.readLine()) != null){
String[] doc_info = line.split("[\t ]+");
int doc_id = Integer.parseInt(doc_info[0]); //文档号,从0开始
for(int i = 1; i < doc_info.length; ++i){
String[] topic_info = doc_info[i].split(":"); //对应的主题信息
int topic_id = Integer.parseInt(topic_info[0]); //主题id
int topic_cnt = Integer.parseInt(topic_info[1]); //主题次数
doc_topic_mat[topic_id][doc_id] = topic_cnt;
}
}
br.close();
//计数
int[] doc_cnts = new int[doc_num]; //每个文档对应的主题数量和,即包含词的数目
for(int i = 0; i < doc_num; ++i){ //对每个文档
for(int j = 0; j < topic_num; ++j){ //对每个主题
doc_cnts[i] += doc_topic_mat[j][i];
}
}
//计算概率
double factor = topic_num * alpha;
for(int i = 0; i < doc_num; ++i){ //对每个文档
for(int j = 0; j < topic_num; ++j){ //对每个主题
doc_topic_mat[j][i] = (doc_topic_mat[j][i] + alpha) / (doc_cnts[i] + factor);
}
}
}
/**
* 加载主题_词模型
* @param model_path 主题_词模型位置,对应文件 server_model_0
* @param model_summary_path 主题数目统计,对应文件 server_model_1
* @throws Exception
*/
public void loadTopicWordModel(String model_path, String model_summary_path) throws Exception{
//将计数写入到矩阵中
BufferedReader br = new BufferedReader(new FileReader(model_path));
String line = null;
while((line = br.readLine()) != null){
String[] info = line.split(" ");
int word_id = Integer.parseInt(info[0]); //词id
for(int i = 1; i < info.length; ++i){
String[] topic_info = info[i].split(":"); //对应的每个topic信息
int topic_id = Integer.parseInt(topic_info[0]); //topic id
int topic_cnt = Integer.parseInt(topic_info[1]); //topic计数
topic_vocab_mat[word_id][topic_id] = topic_cnt;
}
}
br.close();
//写每个主题出现的次数
int[] topic_cnts = new int[topic_num]; //主题出现的次数
br = new BufferedReader(new FileReader(model_summary_path));
String[] cnts = br.readLine().split(" ");
for(int i = 1; i < cnts.length; ++i){
String[] cnt_info = cnts[i].split(":");
int topic_id = Integer.parseInt(cnt_info[0]);
int topic_cnt = Integer.parseInt(cnt_info[1]);
topic_cnts[topic_id] = topic_cnt;
}
br.close();
//写概率
double factor = vocab_num * beta; //归一化因子
for(int i = 0; i < vocab_num; ++i){ //每个词
for(int j = 0; j < topic_num; ++j){ //每个主题
topic_vocab_mat[i][j] = (topic_vocab_mat[i][j] + beta) / (topic_cnts[j] + factor);
}
}
}
}
调用示例(注意:下面的 main 方法需要放在一个类中才能编译运行):
/**
 * Example driver: loads both LightLDA model files and writes the top-10
 * keywords of every document to the output file.
 */
public static void main(String[] args) throws Exception {
    // LightLDA output artifacts.
    String docTopicPath = "doc_topic.0";                 // per-document topic counts
    String topicWordPath = "server_0_table_0.model";     // per-word topic counts
    String topicSummaryPath = "server_0_table_1.model";  // per-topic totals
    // Original corpus resources and the report destination.
    String docTitlePath = "merge_texts";                 // document titles
    String vocabPath = "vocab";                          // vocabulary, one word per id
    String outputPath = "result";                        // keyword report output

    // alpha=0.22, beta=0.1, 220 topics, 1,539,967 words, 146,119 documents.
    LDAResult lda = new LDAResult(0.22, 0.1, 220, 1539967, 146119);
    lda.loadTopicWordModel(topicWordPath, topicSummaryPath); // build topic-word distribution
    lda.loadDocTopicModel(docTopicPath);                     // build document-topic distribution

    List<String> titles = Util.getTitles(docTitlePath);  // all document titles
    List<String> words = Util.getVocabs(vocabPath);      // all vocabulary words
    lda.dumpTopResult(10, outputPath, titles, words);    // top-10 keywords per document
}
网友评论