User-Defined Functions (UDF)

Users can register custom functions through SparkSession's udf facility.

- Load the JSON data:

val df = spark.read.json("files\\test.json")

- Register a user-defined function:

spark.udf.register("addName", (name: String) => "Name:" + name)

- Create a temporary view and query it:

df.createOrReplaceTempView("test")
val testDF = spark.sql("select addName(name), name from test")
testDF.show()
/*
+-----------------+----+
|UDF:addName(name)|name|
+-----------------+----+
|        Name:adam|adam|
|        Name:brad|brad|
|        Name:carl|carl|
+-----------------+----+
*/
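Conceptually, the registered UDF is just an ordinary Scala function that Spark applies to each row's column value. A minimal plain-Scala sketch of the same per-row transformation (no Spark required; the sample names are taken from the output above):

```scala
object UdfSketch {
  def main(args: Array[String]): Unit = {
    // The same lambda that was registered as "addName"
    val addName: String => String = name => "Name:" + name

    // Simulate applying it to each row's `name` column
    val names = Seq("adam", "brad", "carl")
    println(names.map(addName).mkString(", "))
    // prints: Name:adam, Name:brad, Name:carl
  }
}
```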
User-Defined Aggregate Functions (Weakly Typed)

A weakly typed user-defined aggregate function is created by extending UserDefinedAggregateFunction and implementing the following methods:

- inputSchema: the schema of the function's input
- bufferSchema: the schema of the intermediate aggregation buffer
- dataType: the return type of the function
- deterministic: whether the function always produces the same result for the same input
- initialize: initializes the buffer before aggregation starts
- update: folds an input row into the buffer
- merge: combines two buffers
- evaluate: computes the final result from the buffer
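The lifecycle of these methods can be illustrated without Spark: each partition gets a fresh buffer (initialize), rows are folded in (update), partition buffers are combined (merge), and evaluate produces the final value. A plain-Scala sketch of that flow, with two hypothetical partitions and assumed sample ages 16, 18, 18:

```scala
object UdafLifecycle {
  def main(args: Array[String]): Unit = {
    // Buffer holds (sum, count), mirroring the bufferSchema above
    type Buffer = Array[Long]

    def initialize(): Buffer = Array(0L, 0L)
    def update(b: Buffer, age: Long): Unit = { b(0) += age; b(1) += 1 }
    def merge(b1: Buffer, b2: Buffer): Unit = { b1(0) += b2(0); b1(1) += b2(1) }
    def evaluate(b: Buffer): Double = b(0).toDouble / b(1)

    // Two simulated partitions of the age column
    val partitions = Seq(Seq(16L, 18L), Seq(18L))
    val merged = partitions
      .map { part => val b = initialize(); part.foreach(update(b, _)); b }
      .reduce { (b1, b2) => merge(b1, b2); b1 }
    println(evaluate(merged))  // average of 16, 18, 18
  }
}
```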
For example, computing the average age:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class AvgAge extends UserDefinedAggregateFunction {
  // Input: a single Long column, the age
  override def inputSchema: StructType = new StructType().add("age", LongType)
  // Buffer: running sum and row count
  override def bufferSchema: StructType = new StructType().add("sum", LongType).add("count", LongType)
  // Result type
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Convert to Double before dividing; otherwise integer division truncates the result
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}

Using the aggregate function:

import org.apache.spark.sql.SparkSession

object UdafDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("UdafDemo").getOrCreate()
    val df = spark.read.json("files\\test.json")
    df.createOrReplaceTempView("test")
    val avgAge = new AvgAge()
    spark.udf.register("avgAge", avgAge)
    val avgAgeDf = spark.sql("select avgAge(age) from test")
    avgAgeDf.show()
  }
}
/*
+------------------+
|       avgage(age)|
+------------------+
|17.333333333333332|
+------------------+
*/
User-Defined Aggregate Functions (Strongly Typed)

A strongly typed aggregate function is defined by extending Aggregator[IN, BUF, OUT] and implementing the following methods:

- zero: initializes the buffer
- reduce: folds an input value into the buffer
- merge: combines two buffers
- finish: computes the final result from the buffer
- bufferEncoder: the encoder for the buffer type
- outputEncoder: the encoder for the output type

Note: the Encoders class provides implementations for the common types. For custom objects (case classes) use Encoders.product; for primitive types use the matching encoder, e.g. Encoders.scalaDouble for Double.
For example, computing the average age:

- Define the Person case class:

case class Person(id: Long, name: String, age: Long)

- Define the buffer case class (its fields must be vars, because reduce and merge mutate the buffer in place):

case class AvgBuffer(var sum: Long, var count: Int)

- Implement the Aggregator:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

class MyAvgAge extends Aggregator[Person, AvgBuffer, Double] {
  override def zero: AvgBuffer = AvgBuffer(0, 0)
  override def reduce(b: AvgBuffer, a: Person): AvgBuffer = {
    b.sum += a.age
    b.count += 1
    b
  }
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  override def finish(reduction: AvgBuffer): Double = reduction.sum.toDouble / reduction.count
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
- Compute the average age with the custom function:

import org.apache.spark.sql.SparkSession

object Udaf1Demo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("Udaf1Demo").getOrCreate()
    import spark.implicits._
    val df = spark.read.json("files\\test.json")
    val avgAge = new MyAvgAge
    val avgCol = avgAge.toColumn.name("avgAge")
    val ds = df.as[Person]
    ds.select(avgCol).show()
  }
}
/*
+------------------+
|            avgAge|
+------------------+
|17.333333333333332|
+------------------+
*/
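The Aggregator contract requires that reduce and merge compose consistently: folding the whole dataset must give the same result as folding partitions independently and then merging the buffers. A plain-Scala sketch of that property, mirroring the AvgBuffer/MyAvgAge logic above without Spark (ages 16, 18, 18 are assumed sample values):

```scala
object AggregatorContract {
  // Mirrors AvgBuffer/MyAvgAge, minus the Spark dependency
  final case class AvgBuffer(var sum: Long, var count: Int)

  def zero: AvgBuffer = AvgBuffer(0, 0)
  def reduce(b: AvgBuffer, age: Long): AvgBuffer = { b.sum += age; b.count += 1; b }
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count

  def main(args: Array[String]): Unit = {
    val ages = Seq(16L, 18L, 18L)

    // Single fold over all values
    val whole = finish(ages.foldLeft(zero)(reduce))

    // Two "partitions" folded independently, then merged
    val (left, right) = ages.splitAt(2)
    val partitioned = finish(merge(left.foldLeft(zero)(reduce), right.foldLeft(zero)(reduce)))

    println(whole == partitioned)  // prints: true
  }
}
```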