美文网首页
SparkSQL自定义 UDF 函数median求中位数

SparkSQL自定义 UDF 函数median求中位数

作者: 程序员网址导航 | 来源:发表于2019-08-13 21:04 被阅读0次

    原文:SparkSQL自定义 UDF 函数median求中位数

    前言


    我的场景:提供一个聚合组件操作Spark的DataFrame,然后支持先分组在聚合的功能,这里聚合要求支持最大值个数、求和、去重后求和、均值、中位数、最大值、最小值、方差、标准差、唯一值个数、唯一值、归一化等。

    实现下来发现除中位数和归一化外其他聚合均有内置函数,实现起来也就很容易了。
    但是在分组后计算中位数这里卡了很长时间,最后的解决办法是:自定义一个UDF函数实现分组后中位数的计算

    自定义中位数函数:CustomMedian.scala

    /**
        * 自定义计算中位数聚合函数
        * qi.wang<Email>1124602935@qq.com</Email>
        */
      object CustomMedian extends UserDefinedAggregateFunction {
    
        override def inputSchema: StructType = StructType(StructField("input", StringType) :: Nil)
        override def bufferSchema: StructType = StructType(StructField("sum", StringType) :: StructField("count", StringType) :: Nil)
        override def dataType: DataType = DoubleType
        override def deterministic: Boolean = true // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
    
        override def initialize(buffer: MutableAggregationBuffer): Unit = {
          buffer(0) = ""
        }
    
        override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
          if (!input.isNullAt(0)) {
            buffer(0) = buffer.get(0) + "," + input.get(0)
          }
        }
    
        override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
          buffer1.update(0, buffer1.get(0) + "," + buffer2.get(0))
        }
    
        override def evaluate(buffer: Row): Any = {
          val list = new util.ArrayList[Integer]
          val stringList:Array[String] = buffer.getString(0).split(",")
          for (s <- stringList) {
            if (StringUtils.isNotBlank(s))
              list.add(s.toInt)
          }
          Collections.sort(list)
          val size = list.size
          var num:Double = 0L
          if (size % 2 == 1) num = list.get(((size+1) / 2) - 1).toDouble
          if (size % 2 == 0) num = (list.get(size / 2 - 1) + list.get(size / 2)) / 2.00
          num
        }
      }
    

    函数测试

    1. 造一个数据文件:/tmp/data.csv, 内容如下
    id|name|mobile|idnumber
    10|aa|11111111111|111111111111111111
    12|bb|12321321321|213123123213333333
    13|aa|21312332322|333333333333333334
    15|dd|23114567888|872837482374932794
    17|bb|44444444444|827183787373733333
    18|bb|55555555555|823048320999399999
    
    1. 测试代码
    package www.relaxheart.cn
    
    import www.relaxheart.cn.CustomMedian
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.types._
    import scala.util.Random
    
    
    /**
      * @author 王琦<QQ.Email>1124602935@qq.com</QQ.Email>
      * @date 19/8/13 下午20:33
      * @description
      */
    object MedianUDFTest extends App {
    
      val spark = SparkSession.builder().master("local[*]").appName("MedianUDFTest").config("spark.sql.crossJoin.enabled", "true").getOrCreate()
    
    // 读取data.csv得到RDD
      val rdd = spark.sparkContext.textFile("/tmp/data.csv")
    
      // 从第一行数据中获取最后转成的DataFrame应该有多少列 并给每一列命名
      val colNames = rdd.first.split("\\|")
    
      // 设置DataFrame的结构
      val schema = StructType(colNames.map(fieldName => StructField(fieldName, StringType)))
    
      // 对每一行的数据进行处理
      val rowRDD = rdd.filter(_.split("\\|")(0) != "id").map(_.split("\\|")).map(p => Row(p: _*))
    
      // 创建DataFrame
      val data = spark.createDataFrame(rowRDD, schema)
    
      // 创建临时表
      val tmpTable = "_table"+System.currentTimeMillis()+Random.nextInt(10000000)
      data.createOrReplaceTempView(tmpTable)
    
     // 这步很关键,注册我们的自定义中位数函数
      spark.udf.register("median",  CustomMedian)
    
      // 利用SparkSQL + 自定义中位数函数实现分组后求中位数
      // 这里对测试数据按name进行分组,然后组内id的中位数
      val medianGroupDF = spark.sql(s"select name , median(id) as median from $tmpTable group by name")
    
      // 打印分组中位数聚合结果
      medianGroupDF.show()
    }
    

    结果验证

    image.png

    看打印结果是符合我们预期的。

    个人博客网站:王琦的个人兴趣分享网站 | RelaxHeart网 | Tec博客

    相关文章

      网友评论

          本文标题:SparkSQL自定义 UDF 函数median求中位数

          本文链接:https://www.haomeiwen.com/subject/qacijctx.html