为什么要自定义函数
虽然官方提供的sql函数已经很多,并且很强大了,但是有时候并不是都能满足我们的业务需求。除此之外,编写自定义函数能够让我们更加了解官方给定函数的底层实现。
函数的分类
sql函数一共分为三类
-
UDF[一条数据,一个结果]
1)UDF:一行进入,一行出 -
UDAF[多条数据,一个结果,聚合函数]
1)UDAF:输入多行,返回一行。
2)Spark3.x推荐使用extends Aggregator自定义UDAF,属于强类型的Dataset方式。
3)Spark2.x使用extends UserDefinedAggregateFunction,属于弱类型的DataFrame -
UDTF[expload (spark不支持)]
输入一行,返回多行(Hive);
SparkSQL中没有UDTF,Spark中用flatMap即可实现该功能
如何自定义函数
步骤
1.定义一个函数
2.注册:sparkSession.udf.register("函数名称",对应的函数)
3.使用:在sql中进行使用
自定义UDF函数
需求:字符填充,长度由用户自定,填充字符由用户自定
如:customFill
("tom","*",8) ;结果 *****tom
- 创建SparkSession
val sparkSession=SparkSession.builder().master("local[4]").appName("test").getOrCreate()
- 准备测试数据
// 数据准备
val list=List(
Student(2,"绣花",16,"女",1),
Student(5,"翠花",19,"女",2),
Student(9,"王菲菲",20,"女",1),
Student(11,"小惠",23,"女",1),
Student(12,"梦雅",25,"女",3)
)
// 为了方便,定义了一个样例类
case class Student(id:Int,name:String,age:Int,sex:String,classId:Int)
- 将数据注册成表
// 导入隐式转换
import sparkSession.implicits._
// 转成 DataFrame
val frame: DataFrame = list.toDF()
// 注册成表
frame.createOrReplaceTempView("student")
- 自定义函数
/**
* 自定义sql函数
* @param coll 类名
* @param symbol 符号
* @param length 长度
*/
def customFill(coll:String,symbol:String,length:Int): String ={
if(coll.length>=length) coll
else {
symbol*(length-coll.length)+coll
}
}
5.注册自定义函数
name:第一个参数,给函数指定一个名称
func:将自定义函数传进去,
注意:以 def 声明的称为了方法
,方法转函数(其实都是一个意思),需要在后面接上_
//注册函数
sparkSession.udf.register("customFill",customFill _)
- 编写sql,调用自定义函数并执行
// 编写sql,
val frame1: DataFrame = sparkSession.sql(
"""
|select id,customFill(name,'*',8) as name,age,sex,classId from student
|""".stripMargin)
// 执行
frame1.show()
- 运行结果
+---+-----------+---+---+-------+
| id| name|age|sex|classId|
+---+-----------+---+---+-------+
| 2| ******绣花| 16| 女| 1|
| 5| ******翠花| 19| 女| 2|
| 9|*****王菲菲| 20| 女| 1|
| 11| ******小惠| 23| 女| 1|
| 12| ******梦雅| 25| 女| 3|
+---+-----------+---+---+-------+
- 完整代码
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.junit.Test
class SparkFunction {
val sparkSession=SparkSession.builder().master("local[4]").appName("test").getOrCreate()
@Test
def demo01: Unit ={
// 数据准备
val list=List(
Student(2,"绣花",16,"女",1),
Student(5,"翠花",19,"女",2),
Student(9,"王菲菲",20,"女",1),
Student(11,"小惠",23,"女",1),
Student(12,"梦雅",25,"女",3)
)
// 导入隐式转换
import sparkSession.implicits._
// 转成 DataFrame
val frame: DataFrame = list.toDF()
// 注册成表
frame.createOrReplaceTempView("student")
//注册函数
sparkSession.udf.register("customFill",customFill _)
// 编写sql
val frame1: DataFrame = sparkSession.sql(
"""
|select id,customFill(name,'*',8) as name,age,sex,classId from student
|""".stripMargin)
// 执行
frame1.show()
}
case class Student(id:Int,name:String,age:Int,sex:String,classId:Int)
/**
* 自定义sql函数
* @param coll 类名
* @param symbol 符号
* @param length 长度
*/
def customFill(coll:String,symbol:String,length:Int): String ={
if(coll.length>=length) coll
else {
symbol*(length-coll.length)+coll
}
}
}
自定义UDAF函数
使用弱类型实现UDAF函数
步骤:
-
创建一个类
-
继承
UserDefinedAggregateFunction
抽象类(spark3.x
版本中已标志为过期) -
实现里面的抽象方法。
//指定输入列的参数类型;需要指定为
StructType
类型
override definputSchema
: StructType = ???
//指定中间变量的类型
override defbufferSchema
: StructType = ???
//指定聚合函数的结果类型
override defdataType
: DataType = ???
//一致性指定(是否以同样输入返回同样的结果)
override defdeterministic
: Boolean = ???
//初始化中间变量的值
override definitialize
(buffer: MutableAggregationBuffer): Unit = ???
//累加 [在每个task中执行]
override defupdate
(buffer: MutableAggregationBuffer, input: Row): Unit = ???
//合并所有task中该分组的所有的数据
override defmerge
(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
//计算得到最终结果
override defevaluate
(buffer: Row): Any = ??? -
创建自定义UDAF对象
-
注册自定义函数
-
编写sql并使用
数据准备
// 数据准备
val list=List(
Student(2,"绣花",16,"女",1),
Student(5,"翠花",19,"女",2),
Student(9,"王菲菲",20,"女",1),
Student(11,"小惠",23,"女",1),
Student(12,"梦雅",25,"女",3)
)
需求
统计用户的平均年龄(总年龄/总人数
)
自定义函数
/**
* 使用弱类型定义UDAF函数
*/
class CustomUdafByWeak extends UserDefinedAggregateFunction{
/**
* 指定输入列的参数类型;需要指定为`StructType`类型
* @return
*/
override def inputSchema: StructType = {
// input 随便指定
// LongType 输入进来的是年龄,所以需要指定IntegerType或LongType类型,
val fields=Array(StructField("input",LongType))
StructType(fields)
}
/**
* 指定中间变量的类型
* @return
*/
override def bufferSchema: StructType = {
// 当接收到输入的年龄时,肯定需要存起来,记录年龄总和(sum),次数(count)等,方便最终求平均年龄
val fields=Array(
//定义,记录总年龄
StructField("sum",LongType),
//定义,记录次数
StructField("count",LongType)
)
StructType(fields)
}
/**
* 指定聚合函数的结果类型
* @return
*/
override def dataType: DataType = DoubleType
/**
* 一致性指定
* @return
*/
override def deterministic: Boolean = true
/**
* 初始化中间变量的值
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// buffer中存放在中间变量数据
// 在 bufferSchema 中定义了中间变量的类型,此时需要对中间变量进行设置
// 默认的话,总年龄应该为0,总次数也应该为0
//如何获取?sum 和 count ?
// buffer.getAs[类型](根据角标取值)
//如何设置值呢?
// buffer(角标)= value
// buffer.update(角标,value)
buffer(0)= 0L // 总年龄
buffer(1)= 0L // 总次数
}
/**
* 累加 [在每个task中执行]
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// buffer中存放在中间变量数据
// input 当前输入的年龄
// 获取 上一次sum
val preSum:Long = buffer.getAs[Long](0)
// 获取 上一次count
val preCount: Long = buffer.getAs[Long](1)
// 从input中取出年龄 在 inputSchema函数中只指定了一个参数,所以用角标0取值即可。
val age=input.getAs[Long](0)
//重新修改值
buffer.update(0,preSum+age)
buffer.update(1,preCount+1)
}
/**
* 合并所有task中该分组的所有的sum与count
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// 同样的操作,取值赋值
// 获取 上一次sum
val preSum:Long = buffer1.getAs[Long](0)
// 获取 上一次count
val preCount: Long = buffer1.getAs[Long](1)
// 取各个分区的 sum 和 count
val partitionSum:Long = buffer2.getAs[Long](0)
val partitionCount:Long = buffer2.getAs[Long](1)
// 累加保存
buffer1.update(0,preSum+partitionSum)
buffer1.update(1,preCount+partitionCount)
}
/**
* 计算得到最终结果
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
// 计算平均年龄
// sum / count =avg
// 总年龄
val sum=buffer.getAs[Long](0)
// 总次数
val count=buffer.getAs[Long](1)
sum.toDouble/count
}
}
测试
@Test
def demo02(): Unit ={
val sparkSession=SparkSession.builder().master("local[4]").appName("test").getOrCreate()
// 数据准备
val list=List(
Student(2,"绣花",16,"女",1),
Student(5,"翠花",19,"女",2),
Student(9,"王菲菲",20,"女",1),
Student(11,"小惠",23,"女",1),
Student(12,"梦雅",25,"女",3)
)
// 导入隐式转换
import sparkSession.implicits._
// 注册成表
val df: DataFrame = list.toDF("id","name","age","sex","class_id")
df.createOrReplaceTempView("student")
// 创建自定义UDAF对象
val fun=new CustomUdafByWeak
// 注册
sparkSession.udf.register("custom_avg",fun)
// 编写sql
val df2: DataFrame = sparkSession.sql(
"""
|select custom_avg(age) from student
|""".stripMargin)
df2.show()
}
结果
+-------------------------------------+
|customudafbyweak(CAST(age AS BIGINT))|
+-------------------------------------+
| 20.6|
+-------------------------------------+
使用强类型实现UDAF函数
@Stable
@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
" via the functions.udaf(agg) method.", "3.0.0")
abstract class UserDefinedAggregateFunction extends Serializable {...}
在spark 3.x中UserDefinedAggregateFunction
已经被弃用了,目前推荐的是使用Aggregator[IN, BUF, OUT]
。
import org.apache.spark.sql.expressions.Aggregator
它需要我们指定三个类型(参数语义和上面是一样的)
IN
:输入类型
BUF
:中间类型
OUT
:最终输出类型
步骤:
-
定义
class
-
继承
Aggregator
指定IN
、BUF
、OUT
参数类型 -
重写内部方法
// 初始化中间变量
override defzero
: ParamBuff = ???
// 在每个分区中针对每个组进行合并
override defreduce
(b: ParamBuff, a: Long): ParamBuff = ???
// 在新的RDD分区中针对每个组的所有父RDD分区结果进行合并
override defmerge
(b1: ParamBuff, b2: ParamBuff): ParamBuff = ???
// 最终结果计算
override deffinish
(reduction: ParamBuff): Double = ???
// 指定中间变量的编码方式
override defbufferEncoder
: Encoder[ParamBuff] = ???
// 指定结果类型的编码方式
override defoutputEncoder
: Encoder[Double] = ??? -
创建自定义UDAF类
val fun = new CustomUdafByStrong
- 导入
import org.apache.spark.sql.functions._
转换成udaf
// 转换成udaf
import org.apache.spark.sql.functions._
// 创建自定义UDAF对象
val func = udaf(fun)
- 注册
parkSession.udf.register("custom_avg",func)
- 调用
自定义UDAF函数
/**
* 中间变量需要两个参数,
* @param sum 计算年龄总数
* @param count // 计算年龄个数
*/
case class ParamBuff(sum:Long,count:Long)
/**
* 使用强类型定义UDAF函数
*/
class CustomUdafByStrong extends Aggregator[Long,ParamBuff,Double]{
/**
* 初始化中间变量
* @return
*/
override def zero: ParamBuff = {
ParamBuff(0L,0L)
}
/**
* 在每个分区中针对每个组进行合并
* @param b ParamBuff 样例类
* @param a a 传入进来的年龄值
* @return
*/
override def reduce(b: ParamBuff, a: Long): ParamBuff = {
// 获取总年龄
ParamBuff(b.sum+a,b.count+1)
}
/**
* 在新的RDD分区中针对每个组的所有父RDD分区结果进行合并
* @param b1
* @param b2
* @return
*/
override def merge(b1: ParamBuff, b2: ParamBuff): ParamBuff = {
ParamBuff(b1.sum+b2.sum,b1.count+b2.count)
}
/**
* 最终结果计算
* @param reduction
* @return
*/
override def finish(reduction: ParamBuff): Double = {
reduction.sum.toDouble/reduction.count
}
/**
* 指定中间变量的编码方式
* @return
*/
override def bufferEncoder: Encoder[ParamBuff] = Encoders.product[ParamBuff]
/**
* 指定结果类型的编码方式
* @return
*/
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
测试
@Test
def demo03(): Unit ={
// 导入隐式转换
import sparkSession.implicits._
// 注册成表
val df: DataFrame = list.toDF("id","name","age","sex","class_id")
df.createOrReplaceTempView("student")
// 转换成udaf
import org.apache.spark.sql.functions._
// 创建自定义UDAF对象
val func = udaf(new CustomUdafByStrong)
// 注册
sparkSession.udf.register("custom_avg",func)
// 编写sql
val df2: DataFrame = sparkSession.sql(
"""
|select custom_avg(age) from student
|""".stripMargin)
df2.show()
}
结果
+---------------------------------------+
|customudafbystrong(CAST(age AS BIGINT))|
+---------------------------------------+
| 20.6|
+---------------------------------------+
网友评论