美文网首页
Spark开发--Spark SQL--自定义函数(十五)

Spark开发--Spark SQL--自定义函数(十五)

作者: 无剑_君 | 来源:发表于2020-04-14 16:23 被阅读0次

    函数:http://spark.apache.org/docs/latest/api/sql/index.html

    一、自定义函数简介

    在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
    UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
    UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等

    二、自定义UDF函数

    自定义一个UDF函数需要继承UserDefinedAggregateFunction类,并实现其中的8个方法。
    通过spark.udf.register("funcName", func) 来进行注册
    使用:select funcName(name) from people 来直接使用

    1. 匿名函数注册UDF

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    object MySpark {
    
      def main(args: Array[String]) {
        // 定义应用名称
        val conf = new SparkConf().setAppName("mySpark0")
        conf.setMaster("spark://master:7077")
        conf.setJars(Seq("/root/SparkTest.jar"))
    
        // 创建SparkSession对象
        val spark = SparkSession.builder()
          .appName("DataFrameAPP")
          .config(conf)
          .getOrCreate()
    
        // 构造测试数据,有两个字段、名字和年龄
        val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
        //创建测试df
        val userDF = spark.createDataFrame(userData).toDF("name", "age")
        userDF.show
        // 注册一张user表
        userDF.createOrReplaceTempView("user")
        // 匿名方法注册函数
        spark.udf.register("strLen", (str: String) => str.length())
        // 函数使用
        spark.sql("select name,strLen(name) as name_len from user").show
    
      }
    }
    
    // 执行结果
    +-----+--------+
    | name|name_len|
    +-----+--------+
    |  Leo|       3|
    |Marry|       5|
    | Jack|       4|
    |  Tom|       3|
    +-----+--------+
    

    2. 实名函数注册UDF

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    object MySpark {
    
      def main(args: Array[String]) {
        // 定义应用名称
        val conf = new SparkConf().setAppName("mySpark0")
        conf.setMaster("spark://master:7077")
        conf.setJars(Seq("/root/SparkTest.jar"))
    
        // 创建SparkSession对象
        val spark = SparkSession.builder()
          .appName("DataFrameAPP")
          .config(conf)
          .getOrCreate()
    
        // 构造测试数据,有两个字段、名字和年龄
        val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
        //创建测试df
        val userDF = spark.createDataFrame(userData).toDF("name", "age")
        userDF.show
        // 注册一张user表
        userDF.createOrReplaceTempView("user")
        // 实名函数注册
        spark.udf.register("isAdult", isAdult _)
        // 使用
        spark.sql("select name,isAdult(age) as age from user").show
      }
      /**
        * 自定义函数
        * 根据年龄大小返回是否成年 成年:true,未成年:false
        */
      def isAdult(age: Int) = {
        if (age < 18) {
          false
        } else {
          true
        }
      }
    }
    
    // 执行结果
    +-----+-----+
    | name|  age|
    +-----+-----+
    |  Leo|false|
    |Marry| true|
    | Jack|false|
    |  Tom| true|
    +-----+-----+
    

    3. DataFrame用法

    DataFrame的udf方法虽然和Spark Sql的名字一样,但是属于不同的类,它在org.apache.spark.sql.functions里。

    import org.apache.spark.sql.functions._
    //注册自定义函数(通过匿名函数)
    val strLen = udf((str: String) => str.length())
    //注册自定义函数(通过实名函数)
    val udf_isAdult = udf(isAdult _)
    
    

    示例:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    object MySpark {
    
      def main(args: Array[String]) {
        // 定义应用名称
        val conf = new SparkConf().setAppName("mySpark0")
        conf.setMaster("spark://master:7077")
        conf.setJars(Seq("/root/SparkTest.jar"))
    
        // 创建SparkSession对象
        val spark = SparkSession.builder()
          .appName("DataFrameAPP")
          .config(conf)
          .getOrCreate()
    
        // 构造测试数据,有两个字段、名字和年龄
        val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
        //创建测试df
        val userDF = spark.createDataFrame(userData).toDF("name", "age")
        userDF.show
        // 注册一张user表
        userDF.createOrReplaceTempView("user")
        import org.apache.spark.sql.functions._
        // 注册自定义函数(通过匿名函数)
        val strLen = udf((str: String) => str.length())
        //注册自定义函数(通过实名函数)
        val udf_isAdult = udf(isAdult _)
        // 使用,通过withColumn添加列
    //    userDF.withColumn("name_len", strLen(col("name")))
    //      .withColumn("isAdult", udf_isAdult(col("age"))).show
        // 通过select添加列
        userDF.select(col("*"), strLen(col("name")) as "name_len",
          udf_isAdult(col("age")) as "isAdult").show
      }
      /**
        * 自定义函数
        * 根据年龄大小返回是否成年 成年:true,未成年:false
        */
      def isAdult(age: Int) = {
        if (age < 18) {
          false
        } else {
          true
        }
      }
    }
    

    执行结果:

    +-----+---+--------+-------+
    | name|age|name_len|isAdult|
    +-----+---+--------+-------+
    |  Leo| 16|       3|  false|
    |Marry| 21|       5|   true|
    | Jack| 14|       4|  false|
    |  Tom| 18|       3|   true|
    +-----+---+--------+-------+
    

    三、自定义UDAF函数

    UDAF:用户自定义的聚合函数,函数本身作用于数据集合,能够在具体操作的基础上进行自定义操作。
    示例1:
    update:各个分组的值内部聚合
    merge:各个节点的同一分组的值聚合
    evaluate:聚合各个分组的缓存值
    自定义函数:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    /**
      *  单词数量统计
      */
    object  StringCount extends UserDefinedAggregateFunction {
      /**
        * inputSchema,指的是,输入数据的类型
        *
        * @return
        */
      override def inputSchema: StructType = {
        StructType(Array(StructField("str", StringType, true)))
      }
    
      /**
        * bufferSchema,指的是,中间进行聚合时,所处理的数据的类型
        *
        * @return
        */
      override def bufferSchema: StructType = {
        // 计数为整型
        StructType(Array(StructField("count", IntegerType, true)))
      }
    
      /**
        * dataType,指的是,函数返回值的类型
        *
        * @return
        */
      override def dataType: DataType = {
        IntegerType
      }
    
      override def deterministic: Boolean = {
        true
      }
    
      /**
        * 为每个分组的数据执行初始化操作
        *
        * @param buffer
        */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0
      }
    
      /**
        * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
        * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
        * 聚和发生在reduce端.
        * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
        * update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
        */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        // 进行累加计数
        buffer(0) = buffer.getAs[Int](0) + 1
      }
    
      /**
        * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
        * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
        * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
        * 也可以是一个节点里面的多个executor合并 reduce端大聚合
        * merge后的结果写如buffer1中
        */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        // 合并操作
        buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
      }
    
      /**
        * 最后,指的是,一个分组的聚合值,如何通过中间的缓存聚合值,最后返回一个最终的聚合值
        * @param buffer
        * @return
        */
      override def evaluate(buffer: Row): Any = {
        // 返回结果
        buffer.getAs[Int](0)
      }
    }
    
    

    注册,测试:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    object MySpark {
    
      def main(args: Array[String]) {
        // 定义应用名称
        val conf = new SparkConf().setAppName("mySpark0")
        conf.setMaster("spark://master:7077")
        conf.setJars(Seq("/root/SparkTest.jar"))
    
        // 创建SparkSession对象
        val spark = SparkSession.builder()
          .appName("DataFrameAPP")
          .config(conf)
          .getOrCreate()
    
        // 构造测试数据,有两个字段、名字和年龄
        val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
        //创建测试df
        val userDF = spark.createDataFrame(userData).toDF("name", "age")
        //    userDF.show
        // 注册一张user表
        userDF.createOrReplaceTempView("user")
        // 注册函数:SQLContext.udf.register(), 添加 StringCount对象
        spark.udf.register("strCount", StringCount)
        // 使用自定义函数,单词计数
        spark.sql("select name,strCount(name) from user group by name").show
      }
    }
    

    因为新增加函数类,如果报错:
    Caused by: java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.sql.execution.aggregate.HashAggregateExec.aggregateExpressions of type scala.collection.Seq
    请重新打包,拷贝jar文件到D:\root:

    原jar
    目标

    示例2:
    自定义函数:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    /**
      *  求平均数
      */
    class CustomerAvg extends UserDefinedAggregateFunction {
      //输入的类型
      override def inputSchema: StructType = StructType(StructField("salary", LongType) :: Nil)
      //缓存数据的类型
      override def bufferSchema: StructType = {
        StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
      }
      //返回值类型
      override def dataType: DataType = LongType
      //幂等性
      override def deterministic: Boolean = true
      //初始化
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = 0L
      }
      //更新 分区内操作
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0)=buffer.getLong(0) +input.getLong(0)
        buffer(1)=buffer.getLong(1)+1L
      }
      //合并 分区与分区之间操作
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
        buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
      }
      //最终执行的方法
      override def evaluate(buffer: Row): Any = {
        buffer.getLong(0)/buffer.getLong(1)
      }
    }
    

    测试:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    object MySpark {
    
      def main(args: Array[String]) {
        // 定义应用名称
        val conf = new SparkConf().setAppName("mySpark0")
        conf.setMaster("spark://master:7077")
        conf.setJars(Seq("/root/SparkTest.jar"))
    
        // 创建SparkSession对象
        val spark = SparkSession.builder()
          .appName("DataFrameAPP")
          .config(conf)
          .getOrCreate()
    
        // 构造测试数据,有两个字段、名字和年龄
        val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
        //创建测试df
        val userDF = spark.createDataFrame(userData).toDF("name", "age")
        //  userDF.show
        // 注册一张user表
        userDF.createOrReplaceTempView("user")
        spark.udf.register("MyAvg",new CustomerAvg)
        // 测试
        spark.sql("select MyAvg(age) avg_age from user").show()
      }
    }
    
    +-------+
    |avg_age|
    +-------+
    |     17|
    +-------+
    

    相关文章

      网友评论

          本文标题:Spark开发--Spark SQL--自定义函数(十五)

          本文链接:https://www.haomeiwen.com/subject/letcoctx.html