美文网首页数据挖掘
fp_growth频繁项集和关联规则Spark ML调用实现

fp_growth频繁项集和关联规则Spark ML调用实现

作者: xiaogp | 来源:发表于2020-11-22 11:05 被阅读0次

    摘要:关联规则置信度支持度提升度规则集数据挖掘Spark

    关联规则

    关联规则是基于统计的无监督学习方法,它基于序列挖掘频繁出现因素组合的模式,进而可以推断出如果出现了A,B,还可能出现C的规则,可以使用的场景包括二分类中需要找到规则集,在推荐中做关联推荐等。
    关联规则的研究对象是事件序列,目的是找到频繁事件组合(项集),用支持度来衡量出现的频数强度,一个频繁项集内部也分为前项后项,为了描述前项的出现推断后项的能力强弱引出置信度,即等于在前项出现的情况下后项出现的比例,再引出提升度,即因为前项的出现导致后项比随机出现概率提升的倍数。

    Spark ML代码实现

    算法接受DataFrame输入,指定输入序列字段,支持度,置信度,freqItemsets输出频繁项集,associationRules输出关联规则,序列字段由groupBy+collect_list构造得到,相当于将一个对象的所有元素聚合成一个序列,transform可以对新dataframe做预测,推荐频繁项集内未出现的元素

    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.ml.fpm.FPGrowth
    import org.apache.spark.ml.fpm.FPGrowthModel
    
    object FpGrowthExample {
      val spark: SparkSession = SparkSession.builder().appName("FpGrowthExample").master("yarn").getOrCreate()
      import spark.implicits._
    
      def main(args: Array[String]): Unit = {
        val df = spark.read.format("csv").option("header", true).load("/user/test/data.txt")
        // 过滤热门词
        val df2 = df.filter(!$"label_value".isin("其他", "有登记联系方式")).distinct()
        df2.cache()
        val ts = df2.groupBy("ent_name").agg(collect_list("label_value").alias("label_value_list"))
    
        // 定义模型阈值
        val model = new FPGrowth()
          .setItemsCol("label_value_list")
          .setMinConfidence(0.5)
          .setMinSupport(0.1)
          .fit(ts)
    
        model.write.overwrite().save("/user/test/SparkMLModel/fpgrowth")
    
        // 载入模型
        val model2 = FPGrowthModel.load("/user/test/SparkMLModel/fpgrowth")
    
        // 查看频繁项集
        val freq = model2.freqItemsets
        // 只查看组合多于1个的项集
        val myfunc1 = udf((x: Any) => {
          val tmp = x.asInstanceOf[scala.collection.mutable.WrappedArray[String]]
          tmp.size > 1
        })
        val freq2 = freq.filter(myfunc1($"items"))
    
        // 查看置信度(关联规则)
        val conf = model2.associationRules
    
        // 输出格式
        val myfunc2 = udf((x: Any) => x.asInstanceOf[scala.collection.mutable.WrappedArray[String]](0))
        val myfunc3 = udf((x: Any) => x.asInstanceOf[scala.collection.mutable.WrappedArray[String]].mkString("+"))
        val conf2 = conf.withColumn("antecedent", myfunc3($"antecedent")).withColumn("consequent", myfunc2($"consequent"))
        // 加入提升度
        val df3 = df2.groupBy("label_value").count()
        val count = df2.select($"ent_name").distinct().count()
        val df4 = df3.withColumn("base", $"count" / count)
        val conf3 = conf2.join(df4, $"consequent" === $"label_value", "left")
        val conf4 = conf3.withColumn("lift", $"confidence" / $"base").select($"antecedent", $"consequent", $"confidence", $"lift")
          .sort($"confidence".desc)
        val freq3 = freq2.withColumn("items", myfunc3($"items"))
    
        // 输出
        conf4.repartition(1).write.format("csv").mode("overwrite").save("/user/test/fpgrowth/conf")
        freq3.repartition(1).write.format("csv").mode("overwrite").save("/user/test/fpgrowth/freq")
    
        conf4.show(10, false)
        spark.stop()
      }
    }
    

    频繁项集:显示序列和频数


    频繁项集.png

    关联规则:其中antecedent代表前项,consequent代表后项,confidencelift分别是置信度和提升度

    关联规则.png

    相关文章

      网友评论

        本文标题:fp_growth频繁项集和关联规则Spark ML调用实现

        本文链接:https://www.haomeiwen.com/subject/vffjiktx.html