美文网首页
Spark 实现优化的线性感知机算法:Pocket PLA

Spark 实现优化的线性感知机算法:Pocket PLA

作者: k_wzzc | 来源:发表于2019-04-21 20:10 被阅读0次

    Spark 实现优化的线性感知机算法:Pocket PLA

    普通感知机存在的问题

    上一篇文章中我们实现了普通的感知机算法,但是只能处理线性可分的是数据集,在训练非线性可分数据集时,结果会在一定范围产生震荡,就如下图所示:

    震荡效果
    图片来源:https://www.leiphone.com/news/201706/QFydbeV7FXQtRIOl.html

    优化:口袋算法(Pocket)

    要处理非线性可分的数据集,我们就需要对算法进行一定的优化,通常使用的优化算法就是口袋PLA(Pocket PLA),它是一种贪心算法,其基本思路是这样的:在寻找最优分类的过程中,不断地根据上一次迭代的结果进行修正,并将每一次计算的解放入一个“口袋”中,在经过有限次的迭代之后,我们从“口袋”中选择最优的一个解作为最终结果,这样得到的解可能是局部最优的。

    代码实现

    import breeze.linalg.{DenseVector => densevector}
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.linalg.DenseVector
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    import scala.collection.mutable.ListBuffer
    import scala.util.Random
    
    /**
      * Created by WZZC on 2019/3/15
      * 通用感知机模型
      **/
    object pocketPla {
      def main(args: Array[String]): Unit = {
    
        val spark = SparkSession
          .builder()
          .appName(s"${this.getClass.getSimpleName}")
          .master("local[*]")
          .getOrCreate()
    
        import spark.implicits._
    
      //  数据加载
        val data = spark.read
          .option("inferSchema", true)
          .option("header", true)
          .csv("F:\\DataSource\\pocketPla.csv")
    
        val schema = data.schema
        val fts = schema.filterNot(_.name == "lable").map(_.name).toArray
    
        val amountVectorAssembler: VectorAssembler = new VectorAssembler()
          .setInputCols(fts)
          .setOutputCol("features")
    
        val vec2Array = udf((vec: DenseVector) => vec.toArray)
    
        val dataFeatrus = amountVectorAssembler
          .transform(data)
          .select($"lable", vec2Array($"features") as "features")
          .cache()
    
        var initW: densevector[Double] = densevector.rand[Double](fts.length) //创建一个初始化的随机向量
        var initb: Double = Random.nextDouble()
        var flag = true
        val lrate = 0.1 // 学习率
        var iteration = 0 //迭代次数
    
        var countError = dataFeatrus.count() //初始化错判个数(取样本大小)
        var resW = initW
        var resB = initb
    
        // 定义判别函数
        val signudf = udf((t: Seq[Double], y: Double) => {
          val wx = initW.dot(densevector(t.toArray))
          val d = wx + initb
          val ny = if (d >= 0) 1 else -1
          ny
        })
    
        while (flag && iteration < 200) {
    
          val df = dataFeatrus.withColumn("sign", signudf($"features", $"lable"))
          val loss = df.where($"sign" =!= $"lable")
          val count = loss.count().toInt
    
         
           //  判断新模型的误判次数是否小于前一次的误判次数
           //  如果小于则更新权值向量和偏置,大于则不更新
        
          if (count < countError) {
            countError = count
            resW = initW
            resB = initb
          }
    
          println(s"迭代第${iteration}次 error:" + count)
    
          if (count == 0) {
            flag = false
          } else {
            // w1 = w0 + ny1x1
            //随机选择一个误判样本
            val rand = Random.nextInt(loss.count().toInt) + 1
    
            val randy = loss
              .withColumn("r", row_number().over(Window.orderBy($"lable")))
              .where($"r" === rand)
              .head()
    
            val y = randy.getAs[Int]("lable")
            initW = initW + densevector(
              randy.getAs[Seq[Double]]("features").toArray
            ).map(_ * y * lrate)
            // b1 = b0 + y
            initb = initb + y * lrate
    
          }
          iteration += 1
    
        }
    
        println(countError, resW, resB)
    
        // 定义判别函数
        val signudfres = udf((t: Seq[Double], y: Double) => {
          val wx = resW.dot(densevector(t.toArray))
          val d = wx + resB
          val ny = if (d >= 0) 1 else -1
          ny
        })
    
        val df = dataFeatrus.withColumn("sign", signudfres($"features", $"lable"))
    
        df.show(100)
    
        spark.stop()
      }
    }
    
    

    参考资料:

    https://www.leiphone.com/news/201706/QFydbeV7FXQtRIOl.html
    林轩田机器学习基石

    相关文章

      网友评论

          本文标题:Spark 实现优化的线性感知机算法:Pocket PLA

          本文链接:https://www.haomeiwen.com/subject/trzjgqtx.html