美文网首页
Spark 自定义聚合函数-求中位数

Spark 自定义聚合函数-求中位数

作者: 灬臣独秀灬 | 来源:发表于2020-03-12 14:36 被阅读0次

    自定义聚合函数的场景

    业务需要统计最接近两年某商品在门店销售价格的中位数

    由于spark 原生并不支持这样的聚合操作,所这个时候自定义聚合函数产生了。
    中位数:所有输入数据排序,取中间的一个结果,或者中间两个结果的平均数。

    自定义聚合函数开发步骤

    1、 自定义类 class,并且继承 UserDefinedAggregateFunction。
    2、 重写父类方法、、以及属性。
    3、 注册自方法 使用 session.udf.register。

    实现类
    package cn.harsons.mbd.fun
    
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    import scala.collection.mutable.ListBuffer
    
    /**
      * 自定义聚合函数
      *
      * @author liyabin
      * @date 2020/3/11 0011
      */
    class Middle extends UserDefinedAggregateFunction {
    
      /**
        * 分割字符串
        */
      val split_str = "_"
    
      // 输入值 类型
      override def inputSchema: StructType = StructType(StructField("data", DoubleType) :: Nil)
    
      // 缓冲类型
      override def bufferSchema: StructType = StructType(StructField("middle", StringType) :: Nil)
    
      // 返回值类型
      override def dataType: DataType = DoubleType
    
      //对于数据一样的情况下 返回值时候一样
      override def deterministic: Boolean = true
    
      /**
        * 初始化时调用
        *
        * @param buffer
        */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0, "")
      }
    
      /**
        * 一个节点统计操作,每次输入一行记录。需要根据旧的缓冲和新来的数据 做逻辑处理
        *
        * @param buffer 缓冲引用
        * @param input  新的值
        */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer.update(0, buffer.get(0).asInstanceOf[String] + split_str + input.getDouble(0).toString)
      }
    
      /**
        * 多条记录时如何处理 -》 其实就是两个Node计算出来的结果合并操作
        *
        * @param buffer1 节点一的缓冲区
        * @param buffer2 节点二缓冲区
        */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1.update(0, buffer1.get(0).asInstanceOf[String] + split_str + buffer2.get(0).asInstanceOf[String])
      }
    
      /**
        * 最后输出 即 函数输出。 这里作用主要是取中位数。
        *
        * @param buffer 汇集后的缓冲区
        * @return
        */
      override def evaluate(buffer: Row): Any = {
    
        val str = buffer.get(0).asInstanceOf[String]
        val arrays = str.split(split_str)
        val list = new ListBuffer[Double]
        for (str <- arrays) {
          if (str != null && !str.isEmpty) {
            list.append(str.toDouble)
          }
        }
        if (list.isEmpty) {
          return null
        }
        val sorted = list.sorted
        var size = sorted.size
        size = sorted.size
        // 偶数
        if (size % 2 == 0) {
          val middle_first = size / 2
          val middle_second = (size / 2) - 1
          (sorted(middle_first) + sorted(middle_second)) / 2
        } else {
          sorted(size / 2)
        }
      }
    }
    
    
    执行查询
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").getOrCreate()
        val middle = spark.udf.register("middle", new Middle)
        val data = spark.createDataFrame(Seq(
          ("篮球", 56.0), ("足球", 66.0), ("高尔夫", 666.0),
          ("篮球", 57.0), ("足球", 166.0), ("高尔夫", 424.0),
          ("篮球", 58.0), ("足球", 266.0), ("高尔夫", 369.0),
          ("篮球", 59.0), ("足球", 111.0), ("高尔夫", 99.0),
          ("篮球", 66.0), ("足球", 99.0), ("高尔夫", 100.0))).toDF("name", "price")
        data.createOrReplaceTempView("orders")
        spark.sql("select name , middle(price) as  middlePrice from orders group by name ").show(10)
        spark.stop()
      }
    
    结果输出
    image.png
    踩过的坑

    楼主也是刚接触Spark,刚接触这个自定义函数时使用的是强类型自定义聚合函数。当时是想着使用ListBuffer 还缓冲列中所有结果,发现使用ListBuffer Spark 在生成代码时会报错,类型不支持。后面改成弱类型的ObjectType 也是报错。最终无奈之下只能用String 拼接。拼接完后在切割。如果大佬有好的解决办法还请赐教 !

    相关文章

      网友评论

          本文标题:Spark 自定义聚合函数-求中位数

          本文链接:https://www.haomeiwen.com/subject/lkvujhtx.html