Counting Neighboring Words over Large Data Sets


Author: 王知无 | Published 2020-02-02 18:29

    This problem is somewhat similar to some of the search problems on LeetCode.

    The problem: for every word, count how often each other word appears within a window of two positions before and after it. For example, given the sequence w1, w2, w3, w4, w5, w6, the neighbors of w3 within that window are w1, w2, w4 and w5.


    The final output has the form (word, neighbor, frequency), where frequency is the relative frequency: the number of times neighbor appears in word's window divided by the total number of neighbor occurrences for word.
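    To make the pair generation concrete, here is a minimal Scala sketch (window = 2; the token and variable names are purely illustrative) of the pairs a single line produces. It mirrors the window logic used by every implementation below.

    // generate (word, neighbor) pairs for one tokenized line, window = 2
    val tokens = Array("w1", "w2", "w3", "w4", "w5", "w6")
    val window = 2
    val pairs = for {
      i <- tokens.indices
      start = math.max(0, i - window)
      end = math.min(tokens.length - 1, i + window)
      j <- start to end if j != i
    } yield (tokens(i), tokens(j))
    // e.g. the pairs for "w3" are (w3,w1), (w3,w2), (w3,w4), (w3,w5)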

    We implement it in five ways:

    • MapReduce
    • Spark
    • Spark SQL
    • Scala
    • Spark SQL in Scala

    MapReduce

    // map function
    // pair (a PairOfWords key), ONE, totalCount and neighborWindow are fields of the
    // enclosing Mapper class (not shown in this excerpt)
    @Override
    protected void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {

        String[] tokens = StringUtils.split(value.toString(), " ");
        //String[] tokens = StringUtils.split(value.toString(), "\\s+");
        if ((tokens == null) || (tokens.length < 2)) {
            return;
        }
        // for each token, emit one (word, neighbor) pair per position inside the window
        for (int i = 0; i < tokens.length; i++) {
            tokens[i] = tokens[i].replaceAll("\\W+", "");

            if (tokens[i].equals("")) {
                continue;
            }

            pair.setWord(tokens[i]);

            int start = (i - neighborWindow < 0) ? 0 : i - neighborWindow;
            int end = (i + neighborWindow >= tokens.length) ? tokens.length - 1 : i + neighborWindow;
            for (int j = start; j <= end; j++) {
                if (j == i) {
                    continue;
                }
                pair.setNeighbor(tokens[j].replaceAll("\\W", ""));
                context.write(pair, ONE);
            }
            // also emit (word, "*") with this word's neighbor count; the reducer uses it as the total
            pair.setNeighbor("*");
            totalCount.set(end - start);
            context.write(pair, totalCount);
        }
    }
    
    
    // reduce function
    // currentWord, totalCount and relativeCount are fields of the enclosing Reducer class (not shown)
    @Override
    protected void reduce(PairOfWords key, Iterable<IntWritable> values, Context context)
            throws IOException, InterruptedException {
        // a neighbor of "*" carries the word's own total neighbor count; remember it as totalCount
        if (key.getNeighbor().equals("*")) {
            if (key.getWord().equals(currentWord)) {
                totalCount += getTotalCount(values);
            } else {
                currentWord = key.getWord();
                totalCount = getTotalCount(values);
            }
        } else {
            // a real (word, neighbor) pair: sum its counts and divide by the word's total
            int count = getTotalCount(values);
            relativeCount.set((double) count / totalCount);
            context.write(key, relativeCount);
        }

    }
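    The reducer above relies on the order-inversion pattern: all keys for a given word must go to the same reducer, and within that word the (word, "*") key must be seen before any (word, neighbor) key (the latter normally comes from the sort order of PairOfWords). The partitioner that enforces the first property is not part of this excerpt; a minimal sketch, written in Scala against the Hadoop API and assuming the same PairOfWords key class, could look like this:

    import org.apache.hadoop.io.IntWritable
    import org.apache.hadoop.mapreduce.Partitioner

    // partition by the word only, so (word, "*") and all (word, neighbor) keys
    // for the same word end up in the same reducer
    class WordPartitioner extends Partitioner[PairOfWords, IntWritable] {
      override def getPartition(key: PairOfWords, value: IntWritable, numPartitions: Int): Int =
        (key.getWord.hashCode & Integer.MAX_VALUE) % numPartitions
    }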
    
    

    Spark

    public static void main(String[] args) {
            if (args.length < 3) {
                System.out.println("Usage: RelativeFrequencyJava <neighbor-window> <input-dir> <output-dir>");
                System.exit(1);
            }
    
            SparkConf sparkConf = new SparkConf().setAppName("RelativeFrequency");
            JavaSparkContext sc = new JavaSparkContext(sparkConf);
    
            int neighborWindow = Integer.parseInt(args[0]);
            String input = args[1];
            String output = args[2];
    
            final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);
    
            JavaRDD<String> rawData = sc.textFile(input);
    
            /*
             * Transform the input to the format: (word, (neighbour, 1))
             */
            JavaPairRDD<String, Tuple2<String, Integer>> pairs = rawData.flatMapToPair(
                    new PairFlatMapFunction<String, String, Tuple2<String, Integer>>() {
                private static final long serialVersionUID = -6098905144106374491L;
    
                @Override
                public java.util.Iterator<scala.Tuple2<String, scala.Tuple2<String, Integer>>> call(String line) throws Exception {
                    List<Tuple2<String, Tuple2<String, Integer>>> list = new ArrayList<Tuple2<String, Tuple2<String, Integer>>>();
                    String[] tokens = line.split("\\s");
                    for (int i = 0; i < tokens.length; i++) {
                        int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value();
                        int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i + brodcastWindow.value();
                        for (int j = start; j <= end; j++) {
                            if (j != i) {
                                list.add(new Tuple2<String, Tuple2<String, Integer>>(tokens[i], new Tuple2<String, Integer>(tokens[j], 1)));
                            } else {
                                // do nothing
                                continue;
                            }
                        }
                    }
                    return list.iterator();
                }
            }
            );
    
            // (word, sum(word))
            //PairFunction<T, K, V> T => Tuple2<K, V>
            JavaPairRDD<String, Integer> totalByKey = pairs.mapToPair(
    
                    new PairFunction<Tuple2<String, Tuple2<String, Integer>>, String, Integer>() {
                private static final long serialVersionUID = -213550053743494205L;
    
                @Override
                public Tuple2<String, Integer> call(Tuple2<String, Tuple2<String, Integer>> tuple) throws Exception {
                    return new Tuple2<String, Integer>(tuple._1, tuple._2._2);
                }
            }).reduceByKey(
                            new Function2<Integer, Integer, Integer>() {
                        private static final long serialVersionUID = -2380022035302195793L;
    
                        @Override
                        public Integer call(Integer v1, Integer v2) throws Exception {
                            return (v1 + v2);
                        }
                    });
    
            JavaPairRDD<String, Iterable<Tuple2<String, Integer>>> grouped = pairs.groupByKey();
    
            // (word, (neighbour, 1)) -> (word, (neighbour, sum(neighbour)))
            //flatMapValues transforms only the values and leaves the keys untouched
            JavaPairRDD<String, Tuple2<String, Integer>> uniquePairs = grouped.flatMapValues(
                    //Function<T1, R> -> R call(T1 v1)
                    new Function<Iterable<Tuple2<String, Integer>>, Iterable<Tuple2<String, Integer>>>() {
                private static final long serialVersionUID = 5790208031487657081L;
                
                @Override
                public Iterable<Tuple2<String, Integer>> call(Iterable<Tuple2<String, Integer>> values) throws Exception {
                    Map<String, Integer> map = new HashMap<>();
                    List<Tuple2<String, Integer>> list = new ArrayList<>();
                    Iterator<Tuple2<String, Integer>> iterator = values.iterator();
                    while (iterator.hasNext()) {
                        Tuple2<String, Integer> value = iterator.next();
                        int total = value._2;
                        if (map.containsKey(value._1)) {
                            total += map.get(value._1);
                        }
                        map.put(value._1, total);
                    }
                    for (Map.Entry<String, Integer> kv : map.entrySet()) {
                        list.add(new Tuple2<String, Integer>(kv.getKey(), kv.getValue()));
                    }
                    return list;
                }
            });
    
            // (word, ((neighbour, sum(neighbour)), sum(word)))
            JavaPairRDD<String, Tuple2<Tuple2<String, Integer>, Integer>> joined = uniquePairs.join(totalByKey);
    
            // ((key, neighbour), sum(neighbour)/sum(word))
            JavaPairRDD<Tuple2<String, String>, Double> relativeFrequency = joined.mapToPair(
                    new PairFunction<Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>>, Tuple2<String, String>, Double>() {
                private static final long serialVersionUID = 3870784537024717320L;
    
                @Override
                public Tuple2<Tuple2<String, String>, Double> call(Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>> tuple) throws Exception {
                    return new Tuple2<Tuple2<String, String>, Double>(new Tuple2<String, String>(tuple._1, tuple._2._1._1), ((double) tuple._2._1._2 / tuple._2._2));
                }
            });
    
            // For saving the output in tab separated format
            // ((key, neighbour), relative_frequency)
            //format each ((word, neighbour), rf) record as a tab-separated String
            JavaRDD<String> formatResult_tab_separated = relativeFrequency.map(
                    new Function<Tuple2<Tuple2<String, String>, Double>, String>() {
                private static final long serialVersionUID = 7312542139027147922L;
    
                @Override
                public String call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception {
                    return tuple._1._1 + "\t" + tuple._1._2 + "\t" + tuple._2;
                }
            });
    
            // save output
            formatResult_tab_separated.saveAsTextFile(output);
    
            // done
            sc.close();
    
        }
    
    

    Spark SQL

    
     public static void main(String[] args) {
            if (args.length < 3) {
                System.out.println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>");
                System.exit(1);
            }
    
            SparkConf sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency");
            //create the SparkSession that Spark SQL needs
            SparkSession spark = SparkSession
                    .builder()
                    .appName("SparkSQLRelativeFrequency")
                    .config(sparkConf)
                    .getOrCreate();
    
            JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
            int neighborWindow = Integer.parseInt(args[0]);
            String input = args[1];
            String output = args[2];
    
            final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);
    
            /*
             * Register a schema; the frequency column is summed in the SQL query below
             * Schema (word, neighbour, frequency)
             */
            StructType rfSchema = new StructType(new StructField[]{
                new StructField("word", DataTypes.StringType, false, Metadata.empty()),
                new StructField("neighbour", DataTypes.StringType, false, Metadata.empty()),
                new StructField("frequency", DataTypes.IntegerType, false, Metadata.empty())});
    
            JavaRDD<String> rawData = sc.textFile(input);
    
            /*
             * Transform the input to the format: (word, (neighbour, 1))
             */
            JavaRDD<Row> rowRDD = rawData
                    .flatMap(new FlatMapFunction<String, Row>() {
                        private static final long serialVersionUID = 5481855142090322683L;
    
                        @Override
                        public Iterator<Row> call(String line) throws Exception {
                            List<Row> list = new ArrayList<>();
                            String[] tokens = line.split("\\s");
                            for (int i = 0; i < tokens.length; i++) {
                                int start = (i - brodcastWindow.value() < 0) ? 0
                                        : i - brodcastWindow.value();
                                int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1
                                        : i + brodcastWindow.value();
                                for (int j = start; j <= end; j++) {
                                    if (j != i) {
                                        list.add(RowFactory.create(tokens[i], tokens[j], 1));
                                    } else {
                                        // do nothing
                                        continue;
                                    }
                                }
                            }
                            return list.iterator();
                        }
                    });
            //create the DataFrame
            Dataset<Row> rfDataset = spark.createDataFrame(rowRDD, rfSchema);
            //register rfDataset as a temp view so it can be queried with SQL
            rfDataset.createOrReplaceTempView("rfTable");
    
            String query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf "
                    + "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a "
                    + "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word";
            Dataset<Row> sqlResult = spark.sql(query);
    
            sqlResult.show(); // print first 20 records on the console
            sqlResult.write().parquet(output + "/parquetFormat"); // saves output in compressed Parquet format, recommended for large projects.
            sqlResult.rdd().saveAsTextFile(output + "/textFormat"); // to see output via cat command
    
            // done
            sc.close();
            spark.stop();
    
        }
    
    
    

    Scala

    def main(args: Array[String]): Unit = {
    
        if (args.size < 3) {
          println("Usage: RelativeFrequency <neighbor-window> <input-dir> <output-dir>")
          sys.exit(1)
        }
    
        val sparkConf = new SparkConf().setAppName("RelativeFrequency")
        val sc = new SparkContext(sparkConf)
    
        val neighborWindow = args(0).toInt
        val input = args(1)
        val output = args(2)
    
        val brodcastWindow = sc.broadcast(neighborWindow)
    
        val rawData = sc.textFile(input)
    
        /* 
         * Transform the input to the format:
         * (word, (neighbour, 1))
         */
        val pairs = rawData.flatMap(line => {
          val tokens = line.split("\\s")
          for {
            i <- 0 until tokens.length
            start = if (i - brodcastWindow.value < 0) 0 else i - brodcastWindow.value
            end = if (i + brodcastWindow.value >= tokens.length) tokens.length - 1 else i + brodcastWindow.value
            j <- start to end if (j != i)
            //yield collects the transformed tuples (word, (neighbour, 1))
          } yield (tokens(i), (tokens(j), 1))
        })
    
        // (word, sum(word))
        val totalByKey = pairs.map(t => (t._1, t._2._2)).reduceByKey(_ + _)
    
        val grouped = pairs.groupByKey()
    
        // (word, (neighbour, sum(neighbour)))
        val uniquePairs = grouped.flatMapValues(_.groupBy(_._1).mapValues(_.unzip._2.sum))
        //join the two RDDs on the word key
        // (word, ((neighbour, sum(neighbour)), sum(word)))
        val joined = uniquePairs join totalByKey
    
        // ((key, neighbour), sum(neighbour)/sum(word))
        val relativeFrequency = joined.map(t => {
          ((t._1, t._2._1._1), (t._2._1._2.toDouble / t._2._2.toDouble))
        })
    
        // For saving the output in tab separated format
        // ((key, neighbour), relative_frequency)
        val formatResult_tab_separated = relativeFrequency.map(t => t._1._1 + "\t" + t._1._2 + "\t" + t._2)
        formatResult_tab_separated.saveAsTextFile(output)
    
        // done
        sc.stop()
      }
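    Two observations on the RDD version (mine, not from the original article): `pairs` feeds both `totalByKey` and `grouped`, so caching it avoids recomputing the flatMap; and the groupByKey + flatMapValues aggregation can be written with reduceByKey, which combines counts map-side instead of shipping every (neighbour, 1) record. A sketch, reusing `pairs` and `totalByKey` from the code above (the new value names are mine):

    pairs.cache() // reused by the two aggregations below

    // ((word, neighbour), count) without materializing each word's full neighbour list
    val neighbourCounts = pairs
      .map { case (word, (neighbour, one)) => ((word, neighbour), one) }
      .reduceByKey(_ + _)
      .map { case ((word, neighbour), cnt) => (word, (neighbour, cnt)) }

    // same final shape as relativeFrequency above: ((word, neighbour), rf)
    val relativeFrequency2 = neighbourCounts
      .join(totalByKey)
      .map { case (word, ((neighbour, cnt), total)) => ((word, neighbour), cnt.toDouble / total) }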
    
    

    Spark SQL in Scala

    def main(args: Array[String]): Unit = {
    
        if (args.size < 3) {
          println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>")
          sys.exit(1)
        }
    
        val sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency")
    
        val spark = SparkSession
          .builder()
          .config(sparkConf)
          .getOrCreate()
        val sc = spark.sparkContext
    
        val neighborWindow = args(0).toInt
        val input = args(1)
        val output = args(2)
    
        val brodcastWindow = sc.broadcast(neighborWindow)
    
        val rawData = sc.textFile(input)
    
        /*
        * Schema
        * (word, neighbour, frequency)
        */
        val rfSchema = StructType(Seq(
          StructField("word", StringType, false),
          StructField("neighbour", StringType, false),
          StructField("frequency", IntegerType, false)))
    
        /* 
         * Transform the input to the format:
         * Row(word, neighbour, 1)
         */
        //convert each line into Rows matching the schema above
        val rowRDD = rawData.flatMap(line => {
          val tokens = line.split("\\s")
          for {
            i <- 0 until tokens.length
            //the plain window rule; unlike the MapReduce version there is no special "*" total record
            start = if (i - brodcastWindow.value < 0) 0 else i - brodcastWindow.value
            end = if (i + brodcastWindow.value >= tokens.length) tokens.length - 1 else i + brodcastWindow.value
            j <- start to end if (j != i)
          } yield Row(tokens(i), tokens(j), 1)
        })
    
        val rfDataFrame = spark.createDataFrame(rowRDD, rfSchema)
        //register the rfTable temp view
        rfDataFrame.createOrReplaceTempView("rfTable")
    
        import spark.sql
    
        val query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf " +
          "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a " +
          "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word"
    
        val sqlResult = sql(query)
        sqlResult.show() // print first 20 records on the console
        sqlResult.write.save(output + "/parquetFormat") // saves output in compressed Parquet format, recommended for large projects.
        sqlResult.rdd.saveAsTextFile(output + "/textFormat") // to see output via cat command
    
        // done
        spark.stop()
    
      }
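    For completeness, the same relative frequency can be computed without an SQL string, using the DataFrame API with a window function. This is a sketch of mine (not part of the original article), reusing the rfDataFrame defined above:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // per (word, neighbour) counts
    val counts = rfDataFrame.groupBy("word", "neighbour").agg(sum("frequency").as("feq_total"))

    // add the per-word total as a column via a window, then take the ratio
    val rf = counts
      .withColumn("total", sum("feq_total").over(Window.partitionBy("word")))
      .select(col("word"), col("neighbour"), (col("feq_total") / col("total")).as("rf"))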
    
    
