美文网首页
Spark 自定义聚合函数-求中位数

Spark 自定义聚合函数-求中位数

作者: 灬臣独秀灬 | 来源:发表于2020-03-12 14:36 被阅读0次

自定义聚合函数的场景

业务需要统计最接近两年某商品在门店销售价格的中位数

由于spark 原生并不支持这样的聚合操作,所这个时候自定义聚合函数产生了。
中位数:所有输入数据排序,取中间的一个结果,或者中间两个结果的平均数。

自定义聚合函数开发步骤

1、 自定义类 class,并且继承 UserDefinedAggregateFunction。
2、 重写父类方法、、以及属性。
3、 注册自方法 使用 session.udf.register。

实现类
package cn.harsons.mbd.fun

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

/**
  * 自定义聚合函数
  *
  * @author liyabin
  * @date 2020/3/11 0011
  */
class Middle extends UserDefinedAggregateFunction {

  /**
    * 分割字符串
    */
  val split_str = "_"

  // 输入值 类型
  override def inputSchema: StructType = StructType(StructField("data", DoubleType) :: Nil)

  // 缓冲类型
  override def bufferSchema: StructType = StructType(StructField("middle", StringType) :: Nil)

  // 返回值类型
  override def dataType: DataType = DoubleType

  //对于数据一样的情况下 返回值时候一样
  override def deterministic: Boolean = true

  /**
    * 初始化时调用
    *
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, "")
  }

  /**
    * 一个节点统计操作,每次输入一行记录。需要根据旧的缓冲和新来的数据 做逻辑处理
    *
    * @param buffer 缓冲引用
    * @param input  新的值
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.get(0).asInstanceOf[String] + split_str + input.getDouble(0).toString)
  }

  /**
    * 多条记录时如何处理 -》 其实就是两个Node计算出来的结果合并操作
    *
    * @param buffer1 节点一的缓冲区
    * @param buffer2 节点二缓冲区
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.get(0).asInstanceOf[String] + split_str + buffer2.get(0).asInstanceOf[String])
  }

  /**
    * 最后输出 即 函数输出。 这里作用主要是取中位数。
    *
    * @param buffer 汇集后的缓冲区
    * @return
    */
  override def evaluate(buffer: Row): Any = {

    val str = buffer.get(0).asInstanceOf[String]
    val arrays = str.split(split_str)
    val list = new ListBuffer[Double]
    for (str <- arrays) {
      if (str != null && !str.isEmpty) {
        list.append(str.toDouble)
      }
    }
    if (list.isEmpty) {
      return null
    }
    val sorted = list.sorted
    var size = sorted.size
    size = sorted.size
    // 偶数
    if (size % 2 == 0) {
      val middle_first = size / 2
      val middle_second = (size / 2) - 1
      (sorted(middle_first) + sorted(middle_second)) / 2
    } else {
      sorted(size / 2)
    }
  }
}

执行查询
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val middle = spark.udf.register("middle", new Middle)
    val data = spark.createDataFrame(Seq(
      ("篮球", 56.0), ("足球", 66.0), ("高尔夫", 666.0),
      ("篮球", 57.0), ("足球", 166.0), ("高尔夫", 424.0),
      ("篮球", 58.0), ("足球", 266.0), ("高尔夫", 369.0),
      ("篮球", 59.0), ("足球", 111.0), ("高尔夫", 99.0),
      ("篮球", 66.0), ("足球", 99.0), ("高尔夫", 100.0))).toDF("name", "price")
    data.createOrReplaceTempView("orders")
    spark.sql("select name , middle(price) as  middlePrice from orders group by name ").show(10)
    spark.stop()
  }
结果输出
image.png
踩过的坑

楼主也是刚接触Spark,刚接触这个自定义函数时使用的是强类型自定义聚合函数。当时是想着使用ListBuffer 还缓冲列中所有结果,发现使用ListBuffer Spark 在生成代码时会报错,类型不支持。后面改成弱类型的ObjectType 也是报错。最终无奈之下只能用String 拼接。拼接完后在切割。如果大佬有好的解决办法还请赐教 !

相关文章

网友评论

      本文标题:Spark 自定义聚合函数-求中位数

      本文链接:https://www.haomeiwen.com/subject/lkvujhtx.html