美文网首页
Spark的简单的自定义函数

Spark的简单的自定义函数

作者: 不愿透露姓名的李某某 | 来源:发表于2019-07-18 12:55 被阅读0次

    package Sparksql02

    import java.lang

    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

    import org.apache.spark.sql.types.{StructField, _}

    import org.apache.spark.sql.{Dataset, Row, SparkSession}

    object testPow {

    def main(args: Array[String]): Unit = {

    val spark=SparkSession.builder()

    .appName("testPow")

    .master("local[*]")

    .getOrCreate()

    val geomean=new GeoMean

    val range: Dataset[lang.Long] = spark.range(1,11)

    range.createTempView("v_range")

    //第一种方法

    //注册函数

    //    spark.udf.register("gm",geomean)

    //  //将range这个Dataset注册成视图

    //

    //    val result = spark.sql("SELECT gm(id) FROM v_range")

          import  spark.implicits._

    val result = range.groupBy().agg(geomean($"id").as("geomean"))

    result.show()

    spark.stop()

    }

    }

    class  GeoMeanextends  UserDefinedAggregateFunction{

    //输入数据的类型

      override def inputSchema: StructType = StructType(List(

    StructField("value",DoubleType)

    ))

    //产生中间结果的数据类型

      override def bufferSchema: StructType = StructType(List(

    //相乘之后返回的积

        StructField("product",DoubleType),

    //参与运算的个数

          StructField("count",LongType)

    ))

    //最终返回结果类型

      override def dataType: DataType = DoubleType

    //确保一致性,一般用true

      override def deterministic: Boolean =true

      //指定初始值

      override def initialize(buffer: MutableAggregationBuffer): Unit = {

    //相乘的初始值

        buffer(0)=1.0

        //参与运算数字的个数

        buffer(1)=0L

      }

    //每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区的运算)

      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    //每有一个数字参与运算就进行相乘(包含历史结果)

        buffer(0)=buffer.getDouble(0)*input.getDouble(0)

    //参与运算数据的个数也有更新

        buffer(1)=buffer.getLong(1)+1L

      }

    //全局聚合

      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    //每个分区进行相乘

        buffer1(0)=buffer1.getDouble(0)*buffer2.getDouble(0)

    //每个分区参与预算的中间结果进行相加

        buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)

    }

    //计算最终的结果

      override def evaluate(buffer: Row): Double = {

    math.pow(buffer.getDouble(0),1.toDouble/buffer.getLong(1))

    }

    }

    相关文章

      网友评论

          本文标题:Spark的简单的自定义函数

          本文链接:https://www.haomeiwen.com/subject/hmtllctx.html