函数:http://spark.apache.org/docs/latest/api/sql/index.html
一、自定义函数简介
在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
二、自定义UDF函数
自定义一个UDF函数需要继承UserDefinedAggregateFunction类,并实现其中的8个方法。
通过spark.udf.register("funcName", func) 来进行注册
使用:select funcName(name) from people 来直接使用
1. 匿名函数注册UDF
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object MySpark {
def main(args: Array[String]) {
// 定义应用名称
val conf = new SparkConf().setAppName("mySpark0")
conf.setMaster("spark://master:7077")
conf.setJars(Seq("/root/SparkTest.jar"))
// 创建SparkSession对象
val spark = SparkSession.builder()
.appName("DataFrameAPP")
.config(conf)
.getOrCreate()
// 构造测试数据,有两个字段、名字和年龄
val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
//创建测试df
val userDF = spark.createDataFrame(userData).toDF("name", "age")
userDF.show
// 注册一张user表
userDF.createOrReplaceTempView("user")
// 匿名方法注册函数
spark.udf.register("strLen", (str: String) => str.length())
// 函数使用
spark.sql("select name,strLen(name) as name_len from user").show
}
}
// 执行结果
+-----+--------+
| name|name_len|
+-----+--------+
| Leo| 3|
|Marry| 5|
| Jack| 4|
| Tom| 3|
+-----+--------+
2. 实名函数注册UDF
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object MySpark {
def main(args: Array[String]) {
// 定义应用名称
val conf = new SparkConf().setAppName("mySpark0")
conf.setMaster("spark://master:7077")
conf.setJars(Seq("/root/SparkTest.jar"))
// 创建SparkSession对象
val spark = SparkSession.builder()
.appName("DataFrameAPP")
.config(conf)
.getOrCreate()
// 构造测试数据,有两个字段、名字和年龄
val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
//创建测试df
val userDF = spark.createDataFrame(userData).toDF("name", "age")
userDF.show
// 注册一张user表
userDF.createOrReplaceTempView("user")
// 实名函数注册
spark.udf.register("isAdult", isAdult _)
// 使用
spark.sql("select name,isAdult(age) as age from user").show
}
/**
* 自定义函数
* 根据年龄大小返回是否成年 成年:true,未成年:false
*/
def isAdult(age: Int) = {
if (age < 18) {
false
} else {
true
}
}
}
// 执行结果
+-----+-----+
| name| age|
+-----+-----+
| Leo|false|
|Marry| true|
| Jack|false|
| Tom| true|
+-----+-----+
3. DataFrame用法
DataFrame的udf方法虽然和Spark Sql的名字一样,但是属于不同的类,它在org.apache.spark.sql.functions里。
import org.apache.spark.sql.functions._
//注册自定义函数(通过匿名函数)
val strLen = udf((str: String) => str.length())
//注册自定义函数(通过实名函数)
val udf_isAdult = udf(isAdult _)
示例:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object MySpark {
def main(args: Array[String]) {
// 定义应用名称
val conf = new SparkConf().setAppName("mySpark0")
conf.setMaster("spark://master:7077")
conf.setJars(Seq("/root/SparkTest.jar"))
// 创建SparkSession对象
val spark = SparkSession.builder()
.appName("DataFrameAPP")
.config(conf)
.getOrCreate()
// 构造测试数据,有两个字段、名字和年龄
val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
//创建测试df
val userDF = spark.createDataFrame(userData).toDF("name", "age")
userDF.show
// 注册一张user表
userDF.createOrReplaceTempView("user")
import org.apache.spark.sql.functions._
// 注册自定义函数(通过匿名函数)
val strLen = udf((str: String) => str.length())
//注册自定义函数(通过实名函数)
val udf_isAdult = udf(isAdult _)
// 使用,通过withColumn添加列
// userDF.withColumn("name_len", strLen(col("name")))
// .withColumn("isAdult", udf_isAdult(col("age"))).show
// 通过select添加列
userDF.select(col("*"), strLen(col("name")) as "name_len",
udf_isAdult(col("age")) as "isAdult").show
}
/**
* 自定义函数
* 根据年龄大小返回是否成年 成年:true,未成年:false
*/
def isAdult(age: Int) = {
if (age < 18) {
false
} else {
true
}
}
}
执行结果:
+-----+---+--------+-------+
| name|age|name_len|isAdult|
+-----+---+--------+-------+
| Leo| 16| 3| false|
|Marry| 21| 5| true|
| Jack| 14| 4| false|
| Tom| 18| 3| true|
+-----+---+--------+-------+
三、自定义UDAF函数
UDAF:用户自定义的聚合函数,函数本身作用于数据集合,能够在具体操作的基础上进行自定义操作。
示例1:
update:各个分组的值内部聚合
merge:各个节点的同一分组的值聚合
evaluate:聚合各个分组的缓存值
自定义函数:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
* 单词数量统计
*/
object StringCount extends UserDefinedAggregateFunction {
/**
* inputSchema,指的是,输入数据的类型
*
* @return
*/
override def inputSchema: StructType = {
StructType(Array(StructField("str", StringType, true)))
}
/**
* bufferSchema,指的是,中间进行聚合时,所处理的数据的类型
*
* @return
*/
override def bufferSchema: StructType = {
// 计数为整型
StructType(Array(StructField("count", IntegerType, true)))
}
/**
* dataType,指的是,函数返回值的类型
*
* @return
*/
override def dataType: DataType = {
IntegerType
}
override def deterministic: Boolean = {
true
}
/**
* 为每个分组的数据执行初始化操作
*
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0
}
/**
* 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
* 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
* 聚和发生在reduce端.
* 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
* update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 进行累加计数
buffer(0) = buffer.getAs[Int](0) + 1
}
/**
* 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
* 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
* 也可以是一个节点里面的多个executor合并 reduce端大聚合
* merge后的结果写如buffer1中
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// 合并操作
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
}
/**
* 最后,指的是,一个分组的聚合值,如何通过中间的缓存聚合值,最后返回一个最终的聚合值
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
// 返回结果
buffer.getAs[Int](0)
}
}
注册,测试:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object MySpark {
def main(args: Array[String]) {
// 定义应用名称
val conf = new SparkConf().setAppName("mySpark0")
conf.setMaster("spark://master:7077")
conf.setJars(Seq("/root/SparkTest.jar"))
// 创建SparkSession对象
val spark = SparkSession.builder()
.appName("DataFrameAPP")
.config(conf)
.getOrCreate()
// 构造测试数据,有两个字段、名字和年龄
val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
//创建测试df
val userDF = spark.createDataFrame(userData).toDF("name", "age")
// userDF.show
// 注册一张user表
userDF.createOrReplaceTempView("user")
// 注册函数:SQLContext.udf.register(), 添加 StringCount对象
spark.udf.register("strCount", StringCount)
// 使用自定义函数,单词计数
spark.sql("select name,strCount(name) from user group by name").show
}
}
因为新增加函数类,如果报错:
Caused by: java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.sql.execution.aggregate.HashAggregateExec.aggregateExpressions of type scala.collection.Seq
请重新打包,拷贝jar文件到D:\root:
目标
示例2:
自定义函数:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
* 求平均数
*/
class CustomerAvg extends UserDefinedAggregateFunction {
//输入的类型
override def inputSchema: StructType = StructType(StructField("salary", LongType) :: Nil)
//缓存数据的类型
override def bufferSchema: StructType = {
StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
}
//返回值类型
override def dataType: DataType = LongType
//幂等性
override def deterministic: Boolean = true
//初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
//更新 分区内操作
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0)=buffer.getLong(0) +input.getLong(0)
buffer(1)=buffer.getLong(1)+1L
}
//合并 分区与分区之间操作
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
}
//最终执行的方法
override def evaluate(buffer: Row): Any = {
buffer.getLong(0)/buffer.getLong(1)
}
}
测试:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object MySpark {
def main(args: Array[String]) {
// 定义应用名称
val conf = new SparkConf().setAppName("mySpark0")
conf.setMaster("spark://master:7077")
conf.setJars(Seq("/root/SparkTest.jar"))
// 创建SparkSession对象
val spark = SparkSession.builder()
.appName("DataFrameAPP")
.config(conf)
.getOrCreate()
// 构造测试数据,有两个字段、名字和年龄
val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
//创建测试df
val userDF = spark.createDataFrame(userData).toDF("name", "age")
// userDF.show
// 注册一张user表
userDF.createOrReplaceTempView("user")
spark.udf.register("MyAvg",new CustomerAvg)
// 测试
spark.sql("select MyAvg(age) avg_age from user").show()
}
}
+-------+
|avg_age|
+-------+
| 17|
+-------+
网友评论