Spark 分布式实现距离判别分析
距离判别
设有两个总体G1,G2,从第一个总体抽取n个样本,从第二个总体中抽取m个样本,每个样本都有p个测量指标。取任一样本实测指标为X=(x1,x2,……xp)',分别计算X到两个总体的距离D1,D2,按距离最近准则判别归类。
距离计算公式(马氏距离):
判别公式:
判别分析
数据展示与说明
在这里插入图片描述某商场从市场随机抽取20中品牌的电视机进行调查,其中13中畅销,7种滞销。按照电视机的质量评分、功能评分、价格手机资料。其中“1”表示畅销,“2”表示滞销,根据该样本建立判别函数,对以后的新样本进行评测。
实现过程
首先也是要自定义一个计算样本均值向量的自定义聚合函数,同上一篇
然后按照公式进行计算
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName(s"${this.getClass.getSimpleName}")
.master("local[*]")
.getOrCreate()
import spark.implicits._
val sc = spark.sparkContext
val irisData = spark.read
.option("header", true)
.option("inferSchema", true)
.csv("F:\\DataSource\\dda.txt")
val schema = irisData.schema
val fts = schema.filterNot(_.name == "class").map(_.name).toArray
val amountVectorAssembler: VectorAssembler = new VectorAssembler()
.setInputCols(fts)
.setOutputCol("features")
val vec2Array = udf((vec: DenseVector) => vec.toArray)
val irisFeatrus = amountVectorAssembler
.transform(irisData)
.select($"class", vec2Array($"features") as "features")
val ui = spark.udf.register("udafMedian", new meanVector(fts.length))
// 计算样本均值向量
val uiGroup = irisFeatrus
.groupBy($"class")
.agg(ui($"features") as "ui", count($"class") as "len")
// 类别、协方差矩阵、均值向量
val covMatrix = irisFeatrus
.join(uiGroup, "class")
.rdd
.map(row => {
val lable = row.getAs[String]("class")
val len = row.getAs[Long]("len")
val u = densevec(row.getAs[Seq[Double]]("ui").toArray)
val x = densevec(row.getAs[Seq[Double]]("features").toArray)
val denseMatrix = (x - u).toDenseMatrix
lable -> (denseMatrix, u, len)
})
.reduceByKey((d1, d2) => {
(DenseMatrix.vertcat(d1._1, d2._1), d1._2, d1._3)
})
.map(tp => {
val len = tp._2._3 - 1
val t: DenseMatrix[Double] = (tp._2._1.t * tp._2._1).map(x => x / len)
(tp._1, t, tp._2._2)
})
val covmBroad = sc.broadcast(covMatrix.collect())
// 定义判别函数
def dfunction(vec: Seq[Double]) = {
covmBroad.value
.map(tp => {
val xui = (densevec(vec.toArray) - tp._3).toDenseMatrix
val d = (xui * inv(tp._2) * xui.t).data.head
(d, tp._1)
})
.minBy(_._1)
._2
}
val nGudf = udf((vec: Seq[Double]) => dfunction(vec))
val predictions = irisFeatrus
.withColumn("nG", nGudf($"features"))
predictions.show(truncate = false)
spark.stop()
}
结果查看:从结果可以看到,仅有一列判别错误
在这里插入图片描述
参考资料:
《多元统计分析及R语言建模》 – 王斌会
网友评论