class MyUDF2 extends UDF2[String,Int,String]{
  /** Joins a name and an age into a single string, e.g. ("Leo", 16) => "Leo and 16". */
  override def call(t1: String, t2: Int): String = s"$t1 and $t2"
}
object UdfDemoTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("UdfDemoTest")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Test data with two fields: name and age.
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    // Build the test DataFrame.
    val userDF = sc.parallelize(userData).toDF("name", "age")
    // userDF.show prints:
    // +-----+---+
    // |name |age|
    // +-----+---+
    // |Leo  |16 |
    // |Marry|21 |
    // |Jack |14 |
    // |Tom  |18 |
    // +-----+---+

    // Register a temporary table named "user".
    userDF.registerTempTable("user")
    // Register the UDF under the SQL name "test"; the third argument is the return type.
    val myTest = new MyUDF2
    sqlContext.udf.register("test", myTest, DataTypes.StringType)
    sqlContext.sql("select *, test(name,age) as n from user").show
    // Expected output:
    // +-----+---+------------+
    // | name|age|           n|
    // +-----+---+------------+
    // |  Leo| 16|  Leo and 16|
    // |Marry| 21|Marry and 21|
    // | Jack| 14| Jack and 14|
    // |  Tom| 18|  Tom and 18|
    // +-----+---+------------+
    sc.stop()
  }
}
另参考以下示例:一个从 JSON 字符串中提取指定字段值的 UDF。
import scala.util.control.NonFatal

import com.alibaba.fastjson.JSON
import org.apache.spark.sql.api.java.UDF2
/**
 * UDF2[String, String, String]:
 * - first type parameter:  the input JSON-formatted string
 * - second type parameter: the name of the field whose value should be extracted
 * - third type parameter:  the returned field value from the JSON string
 *
 * Returns null when the input cannot be parsed as JSON; a missing field also
 * yields null (fastjson's getString returns null for absent keys).
 */
class GetJsonObjectUDF extends UDF2[String,String,String] {
  override def call(json: String, field: String): String = {
    try {
      // The last expression is the result — no explicit `return` needed.
      JSON.parseObject(json).getString(field)
    } catch {
      // Recover only from non-fatal errors (e.g. malformed JSON);
      // fatal errors such as OutOfMemoryError must propagate.
      case NonFatal(e) =>
        e.printStackTrace()
        null
    }
  }
}
数据如下:
+---------+---------------------------------+
|id |lables |
+---------+---------------------------------+
|786 |[1829, 42092, 1766, 179, 1769] |
|185 |[42059, 1748, 1787, 42092] |
|324 |[42059, 1748, 122, 1766] |
|541 |[1763, 42092, 146, 1766, 1775] |
|143 |[1763, 42092, 146, 1814, 42116] |
|572 |[1829, 42092, 1766, 42086, 1769] |
|140 |[1778, 1748, 1787, 42059] |
|184      |[1829, 1763, 42092, 1766, 179]   |
选出 lables 标签中包含 1766 或 179(任意一个即可)的数据:
selectDFWithLables(hiveContext,df1,"1766,179")
// Select rows whose `lables` array<String> column contains any of the given labels.
def selectDFWithLables(hiveContext: HiveContext, dataframe1: DataFrame, lables: String) = {
  import hiveContext.implicits._
  // Tag each row with a boolean helper column, keep the matching rows,
  // then drop the helper column before returning.
  dataframe1
    .withColumn("iscontains", lable_array_contains(col("lables"), lit(lables)))
    .filter($"iscontains" === true)
    .drop("iscontains")
}
/**
 * UDF: returns true when the `lables` array column contains at least one of
 * the comma-separated labels given in the second argument.
 */
def lable_array_contains: UserDefinedFunction = udf[Boolean, collection.mutable.WrappedArray[String], String] {
  (list1, str) =>
    // `exists` short-circuits on the first match — equivalent to the original
    // var/flag-controlled loop, and an empty `list1` naturally yields false.
    str.split(",").exists(list1.contains)
}
网友评论