上一段实习的时候用spark手写了一个tfidf,下面贴上代码并和spark中的源码进行比较。
输入文本(demo):
文档1:a b c d e f g
文档2:a b c d e f
文档3:a b c d e
文档4:a b c d
文档5:a b c
文档6:a b
文档7:a
输出结果:
代码分析
主要有以下几个步骤:
- 读取文件到JavaRDD<String>中
- mapToPair将每行文本映射为doc <标题 : 单词[]>中,后者为分词后的单词数组
- mapValues获取每个文档的词频
- 将文档数进行广播,用于计算idf
- 类似于wordCount, 先将doc中的每个文本对应的去重单词出现次数置为1,然后aggregateByKey统计每个单词出现的文档数,用对应的求idf的公式,就可以求出idf了
- 将表示每个词idf的RDD<map> collect到driver,再进行广播,进行每个文档的tfIdf计算
- 最后写入输出文件
和spark Mllib中tf-idf实现方法的对比
源码中也是将tf计算和idf计算分隔开的,tf计算时也是用了HashMap但是使用了hash函数(hashcode取余numfeatures)将词映射到了一个int作为Key.在计算idf时每个文档使用了一个词语大小的向量来保存每个词是否出现过,累加这些向量就得到了整个数据集中每个词语出现的文档数,即IDF,再利用公式计算,不过源码中使用的是log即以e为底而不是以10为底。
源码中也是用广播的形式将TF和IDF联系起来
public class GenerateTags {
public static void main(String[] args) throws IOException{
SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
// SparkConf conf = new SparkConf().setAppName("video-tags");
JavaSparkContext sc = new JavaSparkContext(conf);
System.setProperty("hadoop.home.dir", "D:\\winutils");
JavaRDD<String> lines = sc.textFile("C:\\Users\\YANGXIN\\Desktop\\test.txt");
//得到每个文档标题和对应的词串
JavaPairRDD<String, String[]> docs = lines.mapToPair(new PairFunction<String, String, String[]>() {
@Override
public Tuple2<String, String[]> call(String s) throws Exception {
String[] doc = s.split(":");
String title = doc[0];
String[] words = doc[1].split(" ");
return new Tuple2<String, String[]>(title, words);
}
});
//得到每个文档的词频
JavaPairRDD<String, Map<String, Double>> docTF = docs.mapValues(new Function<String[], Map<String, Double>>() {
@Override
public Map<String, Double> call(String[] strings) throws Exception {
Map<String, Double> map = new HashMap<String, Double>();
int sum = strings.length;
for(String str : strings){
double cnt = map.containsKey(str) ? map.get(str) : 1;
map.put(str, cnt);
}
for(String str : map.keySet()){
map.replace(str, map.get(str) / sum);
}
return map;
}
});
//文档数
final Broadcast<Long> docCnt = sc.broadcast(docs.count());
//得到每个词的idf值
JavaPairRDD<String, Integer> ones = docs.flatMapToPair(new PairFlatMapFunction<Tuple2<String, String[]>, String, Integer>() {
@Override
public Iterable<Tuple2<String, Integer>> call(Tuple2<String, String[]> stringTuple2) throws Exception {
List<Tuple2<String, Integer>> list = new ArrayList<Tuple2<String, Integer>>();
Set<String> set = new HashSet<String>();
for(String str : stringTuple2._2()){
set.add(str);
}
for(String str : set){
list.add(new Tuple2<>(str, 1));
}
return list;
}
});
//每个单词在多少个文档中出现了
JavaPairRDD<String, Integer> wordDocCnt= ones.aggregateByKey(0, new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception { //同partition下的处理
return integer + integer2;
}
}, new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer integer, Integer integer2) throws Exception { //不同partition下的处理
return integer + integer2;
}
});
JavaPairRDD<String, Double> wordIdf = wordDocCnt.mapValues(new Function<Integer, Double>() {
@Override
public Double call(Integer integer) throws Exception {
return Math.log10((docCnt.getValue() + 1) * 1.0 / (integer + 1)); //计算逆文档频率
}
});
//广播idf值,进行tf-idf计算
Map<String, Double> idfs = wordIdf.collectAsMap();
final Broadcast<Map<String, Double>> idfMap = sc.broadcast(idfs);
//计算每个文档的tf-idf向量
JavaPairRDD<String, TreeMap<Double, String>> TfIdf = docTF.mapValues(new Function<Map<String, Double>, TreeMap<Double, String>>() {
@Override
public TreeMap<Double, String> call(Map<String, Double> stringDoubleMap) throws Exception {
TreeMap<Double, String> map = new TreeMap<Double, String>();
for(Map.Entry<String, Double> entry : stringDoubleMap.entrySet()){
String word = entry.getKey();
Double tf = entry.getValue();
Double idf = idfMap.getValue().get(word);
map.put(tf * idf, word);
}
return map;
}
});
TfIdf.saveAsTextFile("C:\\Users\\YANGXIN\\Desktop\\result");
}
网友评论