美文网首页
Spark -- 基于RDD实现 KNN

Spark -- 基于RDD实现 KNN

作者: k_wzzc | 来源:发表于2018-11-21 22:13 被阅读0次

    Spark -- 基于RDD实现 KNN

    上一篇基于 DataFrame 实现 KNN 的过程中,由于中间使用了笛卡尔积以及大规模的排序,对运算性能有较大影响;经过一定的调整,笔者找到了一个性能相对较好的实现方法。

      /**
       * Classifies every row of `testSet` with a brute-force k-nearest-neighbours
       * majority vote against `trainSet`.
       *
       * The training set is collected to the driver and broadcast, so each test
       * partition scans it locally — no cartesian join, no global sort.
       *
       * @param trainSet labelled training data; must contain the column `cl`
       * @param testSet  data to classify (any `cl` column is dropped first)
       * @param k        number of neighbours taking part in the vote
       * @param cl       name of the label column in `trainSet`
       * @return a DataFrame of the test features plus the predicted label, built
       *         with `trainSet.schema` — NOTE(review): this assumes `cl` is the
       *         LAST column of the training schema; confirm with callers
       */
      def runKnn(trainSet: DataFrame, testSet: DataFrame, k: Int, cl: String) = {
    
        // Test features: drop the label column and parse the rest as doubles.
        // NOTE(review): unlike the training branch below there is no isNumber
        // filter here, so every non-label test column must be numeric.
        val testFeatures: RDD[Seq[Double]] = testSet
          .drop(cl).map(row => {
            row.mkString(",").split(",").map(_.toDouble).toSeq
          }).rdd
    
        // Training rows as (label, feature-vector) pairs; non-numeric tokens
        // (i.e. the label itself) are filtered out of the feature vector.
        val trainFeatures: RDD[(String, Seq[Double])] = trainSet.map(row => {
          val label = row.getAs[String](cl)
          val features: Seq[Double] = row.mkString(",")
            .split(",").filter(NumberUtils.isNumber(_)).map(_.toDouble)
          (label, features)
        }).rdd
    
        // Broadcast the (assumed small) training set to every executor.
        val trainBroad = spark.sparkContext.broadcast(trainFeatures.collect())
    
        val resRDD: RDD[Row] = testFeatures.map(testTp => {
          // Bounded max-heap of the k closest neighbours seen so far; the head
          // is the current worst of the k best. A PriorityQueue replaces the
          // original TreeSet because (a) `fromLessThan(_._2 <= _._2)` was a
          // NON-strict comparator that breaks the Ordering contract, and (b) a
          // TreeSet keyed only on distance silently drops neighbours whose
          // distances are equal, losing votes.
          val byDistance: Ordering[(String, Double)] = Ordering.by(_._2)
          val nearest = mutable.PriorityQueue.empty[(String, Double)](byDistance)
    
          trainBroad.value.foreach { case (label, features) =>
            val dist = distance.Euclidean(testTp, features)
            if (nearest.size < k) {
              nearest.enqueue(label -> dist)
            } else if (dist < nearest.head._2) {
              // Closer than the current k-th neighbour: replace it, O(log k).
              nearest.dequeue()
              nearest.enqueue(label -> dist)
            }
          }
    
          // Majority vote over the k neighbours (a small word-count);
          // ties between classes are broken arbitrarily by maxBy.
          val predicted = nearest.toArray.groupBy(_._1)
            .map(t => (t._1, t._2.length)).maxBy(_._2)._1
    
          Row.merge(Row.fromSeq(testTp), Row(predicted))
    
        })
    
        spark.createDataFrame(resRDD, trainSet.schema)
    
      }
    

    算法测试

    // Load the iris data set, inferring column types from the CSV header.
    val iris = spark.read
      .option("header", true)
      .option("inferSchema", true)
      .csv(inputFile)
    
    // Split the iris data into a test set (30%) and a training set (70%).
    val Array(testSet, trainSet) = iris.randomSplit(Array(0.3, 0.7), 1234L)
    
    val runner = new KNNRunner(spark)
    val predictions = runner.runKnn(trainSet, testSet, 10, "class")
    
    predictions.show(truncate = false)
    
    // 1 when the predicted and true labels agree, 0 otherwise.
    val sameLabel = udf((predicted: String, actual: String) => {
      if (predicted.equals(actual)) 1 else 0
    })
    
    // Join predictions back to the ground truth on the four feature columns
    // and count correct (check = 1) vs. incorrect (check = 0) classifications.
    predictions
      .join(
        testSet.withColumnRenamed("class", "yclass"),
        Seq("sepalLength", "sepalWidth", "petalLength", "petalWidth"))
      .withColumn("check", sameLabel($"class", $"yclass"))
      .groupBy("check").count().show()
     
    
    +-----------+----------+-----------+----------+---------------+
    |sepalLength|sepalWidth|petalLength|petalWidth|class          |
    +-----------+----------+-----------+----------+---------------+
    |4.6        |3.2       |1.4        |0.2       |Iris-setosa    |
    |4.8        |3.0       |1.4        |0.1       |Iris-setosa    |
    |4.8        |3.4       |1.6        |0.2       |Iris-setosa    |
    
    +-----+-----+
    |check|count|
    +-----+-----+
    |    1|   53|
    |    0|    2|
    +-----+-----+
    

    从结果看,两个实现的分类结果是一致的,但本文使用的方法更高效。

    相关文章

      网友评论

          本文标题:Spark -- 基于RDD实现 KNN

          本文链接:https://www.haomeiwen.com/subject/uwbhqqtx.html