UDAF 就是一个多行导成一行的聚合函数,它的过程与MR过程紧密结合

源码0:sum实现
https://www.cnblogs.com/yin-fei/p/10879731.html 这个例子比较简单
源码1:用于统计字符串字符个数的UDAF
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.ql.exec.Description;
/**
 * Hive generic UDAF that returns the total number of characters in all the
 * non-null string values of a column.
 *
 * <p>This resolver validates the argument list and hands back a
 * {@link TotalNumOfLettersEvaluator}, which performs the aggregation.
 */
@Description(name = "letters", value = "_FUNC_(expr) - Returns total number of letters in all the strings of a column.")
public class TotalNumOfLettersGenericUDAF extends AbstractGenericUDAFResolver {

    /**
     * Checks that exactly one STRING argument was supplied and returns the evaluator.
     *
     * @param parameters type information of the arguments passed to the function
     * @return a new {@link TotalNumOfLettersEvaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) if the argument
     *         count is not one, or the argument is not a primitive STRING
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly one argument is expected.");
        }

        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                    "Argument must be PRIMITIVE, but "
                            + oi.getCategory().name()
                            + " was passed.");
        }

        PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
        if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(0,
                    "Argument must be String, but "
                            + inputOI.getPrimitiveCategory().name()
                            + " was passed.");
        }
        return new TotalNumOfLettersEvaluator();
    }

    /**
     * Evaluator that sums string lengths.
     *
     * <p>All running state lives in the {@link LetterSumAgg} buffer passed to each
     * callback; the evaluator itself keeps only object inspectors. Keeping a running
     * total on the evaluator would mix the results of different buffers together.
     */
    public static class TotalNumOfLettersEvaluator extends GenericUDAFEvaluator {
        /** Inspector for the raw input column; assigned in PARTIAL1 and COMPLETE modes. */
        PrimitiveObjectInspector inputOI;
        /** Inspector describing the value returned by terminatePartial() and terminate(). */
        ObjectInspector outputOI;
        /** Inspector for incoming partial sums; assigned in the remaining modes. */
        PrimitiveObjectInspector integerOI;

        /**
         * Records the inspector for this mode's input and declares the output type.
         *
         * @param m          the mode this evaluator instance runs in
         * @param parameters inspectors for the input; exactly one is expected
         * @return an inspector for a Java {@link Integer}, used for both partial and final results
         * @throws HiveException if the superclass initialisation fails
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            super.init(m, parameters);

            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                // The parameter describes original column values (consumed by iterate()).
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                // The parameter describes a partial sum (consumed by merge()).
                integerOI = (PrimitiveObjectInspector) parameters[0];
            }

            // Partial and final results are both a single Integer.
            outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                    ObjectInspectorOptions.JAVA);
            return outputOI;
        }

        /**
         * Aggregation buffer holding the running character count.
         */
        static class LetterSumAgg implements AggregationBuffer {
            int sum = 0;

            /** Adds {@code num} to the running count. */
            void add(int num) {
                sum += num;
            }
        }

        /**
         * Creates an empty aggregation buffer.
         *
         * @return a new {@link LetterSumAgg} whose sum is zero
         */
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new LetterSumAgg();
        }

        /**
         * Clears the given buffer so that it can be reused for another aggregation.
         *
         * @param agg the buffer to clear; must be a {@link LetterSumAgg}
         */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            // Zero the buffer that was passed in; allocating a fresh local object
            // here would leave the caller's buffer holding the previous sum.
            ((LetterSumAgg) agg).sum = 0;
        }

        /**
         * Adds the length of one input value to the buffer. Null inputs are skipped.
         *
         * @param agg        the buffer to update
         * @param parameters the row's argument values; exactly one is expected
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            if (parameters[0] == null) {
                return;
            }
            LetterSumAgg myagg = (LetterSumAgg) agg;
            Object value = inputOI.getPrimitiveJavaObject(parameters[0]);
            // Guard against a null conversion result: String.valueOf(null) is the
            // four-character text "null" and would otherwise be counted.
            if (value != null) {
                myagg.add(String.valueOf(value).length());
            }
        }

        /**
         * Returns the partial sum held in the given buffer.
         *
         * @param agg the buffer to read
         * @return the buffer's current sum as an {@link Integer}
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            // Report this buffer's sum only, so buffers stay independent of each other.
            return ((LetterSumAgg) agg).sum;
        }

        /**
         * Folds one partial sum into the buffer. A null partial is ignored.
         *
         * @param agg     the buffer to update
         * @param partial a partial result, read through {@code integerOI}
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                LetterSumAgg myagg = (LetterSumAgg) agg;
                Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
                myagg.add(partialSum);
            }
        }

        /**
         * Returns the final character count held in the given buffer.
         *
         * @param agg the buffer to read
         * @return the buffer's sum as an {@link Integer}
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            return ((LetterSumAgg) agg).sum;
        }
    }
}
然后,一点一点拆解
getEvaluator
主要用来判定输入时的内容是否存在错误,返回正确的evaluator。evaluator()会返回udaf对数据处理结果。
parameters.length 判定输入参数的长度
oi获取输入参数,判定其是不是PRIMITIVE类型
然后再通过inputOI判定输入是否是一个String类型
最后返回了一个evaluator,名字和下面的类是一样的。
过程中运用到的类/接口:
TypeInfoUtils
getStandardJavaObjectInspectorFromTypeInfo方法会把输入的对象转换为标准对象
ObjectInspector
这里的Category定义了5种类型:基本类型(Primitive),集合(List),键值对映射(Map),结构体(Struct),联合体(Union)
更多关于这个接口的内容:https://blog.csdn.net/weixin_39469127/article/details/89739285
PrimitiveObjectInspector
这里可以看到基本类型包括哪些,byte\char\float\string等等
/**
 * Validates the argument list and returns the evaluator for this UDAF.
 *
 * @param parameters type information of the arguments passed to the function
 * @return a new TotalNumOfLettersEvaluator
 * @throws SemanticException (as UDFArgumentTypeException) when the argument count
 *         is not one, or the argument is not a primitive STRING
 */
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
        throws SemanticException {
    final int argCount = parameters.length;
    if (argCount != 1) {
        throw new UDFArgumentTypeException(argCount - 1, "Exactly one argument is expected.");
    }

    // Resolve the declared type into a standard Java inspector so it can be examined.
    final ObjectInspector argInspector =
            TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);

    final boolean isPrimitive = argInspector.getCategory() == ObjectInspector.Category.PRIMITIVE;
    if (!isPrimitive) {
        final String message =
                "Argument must be PRIMITIVE, but " + argInspector.getCategory().name() + " was passed.";
        throw new UDFArgumentTypeException(0, message);
    }

    final PrimitiveObjectInspector primitiveInspector = (PrimitiveObjectInspector) argInspector;
    final boolean isString =
            primitiveInspector.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.STRING;
    if (!isString) {
        final String message =
                "Argument must be String, but " + primitiveInspector.getPrimitiveCategory().name() + " was passed.";
        throw new UDFArgumentTypeException(0, message);
    }

    return new TotalNumOfLettersEvaluator();
}
TotalNumOfLettersEvaluator
首先定义了三个变量,类型可在上一部分的类/接口部分查到
不同的UDAF定义的变量可能有不同,根据需要进行定义。这些变量用到了哪个阶段,可以根据init()定义。
// Inspector for the raw input column; init() assigns it in PARTIAL1 and COMPLETE modes.
PrimitiveObjectInspector inputOI;
// Inspector for the value the evaluator returns; init() assigns it and returns it.
ObjectInspector outputOI;
// Inspector for incoming partial sums; init() assigns it in every mode other than PARTIAL1/COMPLETE.
PrimitiveObjectInspector integerOI;
init()
用来定义MR各个阶段中的输入输出
关于各个阶段,可以参考 https://blog.csdn.net/weixin_37766087/article/details/100940409
PARTIAL1: 这个是mapreduce的map阶段:从原始数据到部分数据聚合
PARTIAL2: 这个是mapreduce的map端的Combiner阶段,负责在map端合并map的数据:从部分数据聚合到部分数据聚合
FINAL: mapreduce的reduce阶段:从部分数据的聚合到完全聚合
COMPLETE: 如果出现了这个阶段,表示mapreduce只有map,没有reduce,所以map端就直接出结果了:从原始数据直接到完全聚合
在这个例子里inputOI被用在了map阶段,integerOI被用在了reduce和combiner阶段
用到的类:
ObjectInspectorFactory
是创建新ObjectInspector实例的主要方法
/**
 * Records the inspector for this mode's input and declares the output type.
 *
 * @param m          the mode this evaluator instance runs in
 * @param parameters inspectors for the input; exactly one is expected
 * @return an inspector for a Java Integer, used for both partial and final results
 * @throws HiveException if the superclass initialisation fails
 */
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
    assert (parameters.length == 1);
    super.init(m, parameters);

    final PrimitiveObjectInspector onlyParameter = (PrimitiveObjectInspector) parameters[0];
    final boolean consumesPartials = m != Mode.PARTIAL1 && m != Mode.COMPLETE;
    if (consumesPartials) {
        // The parameter describes a partial sum (read by merge()).
        integerOI = onlyParameter;
    } else {
        // The parameter describes original column values (read by iterate()).
        inputOI = onlyParameter;
    }

    // Every mode emits a single Integer.
    outputOI = ObjectInspectorFactory.getReflectionObjectInspector(
            Integer.class, ObjectInspectorOptions.JAVA);
    return outputOI;
}
LetterSumAgg
这个类是用来存放中间聚合结果的,都会实现AggregationBuffer接口,不过不一样的UDAF中间聚合结果的类名称、里面的内容肯定也各有不同。
这个类里有sum存放字符串中字符的个数,add方法用于进行累加计算
/**
 * Aggregation buffer that holds the running character count.
 */
static class LetterSumAgg implements AggregationBuffer {
    /** Characters counted so far; an int field starts at zero. */
    int sum;

    /** Increases the running count by {@code num}. */
    void add(int num) {
        sum = sum + num;
    }
}
getNewAggregationBuffer()
获得一个聚合的缓冲对象,每个map执行一次
https://www.cnblogs.com/yin-fei/p/10879731.html
/**
 * Creates an empty aggregation buffer.
 *
 * @return a new LetterSumAgg whose sum is zero
 */
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
    return new LetterSumAgg();
}
reset()
用于重置聚合值,group by的时候就可能用到,因为聚合的对象不一样了
iterate()
对每一行进行处理,这里可以看到把传入的agg强制转换成了LetterSumAgg(并没有新建实例),并且将输入的内容进行了处理,获取其长度再进行累加
/**
 * Adds the length of one input value to the buffer. A null input is skipped.
 *
 * @param agg        the buffer to update; must be a LetterSumAgg
 * @param parameters the row's argument values; exactly one is expected
 */
public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
    assert (parameters.length == 1);
    final Object rawValue = parameters[0];
    if (rawValue == null) {
        return;
    }
    final LetterSumAgg buffer = (LetterSumAgg) agg;
    // Convert through the input inspector, then measure the value's string form.
    final Object javaValue = inputOI.getPrimitiveJavaObject(rawValue);
    final int letterCount = String.valueOf(javaValue).length();
    buffer.add(letterCount);
}
terminatePartial()
输出map/combiner结果,这里就是把中间对象的sum值加和起来
/**
 * Returns the partial sum held in the given buffer.
 *
 * @param agg the buffer to read; must be a LetterSumAgg
 * @return the buffer's current sum as an Integer
 */
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
    // Return this buffer's own sum. Accumulating into a field on the evaluator
    // would make each call also include the sums of buffers terminated before it.
    LetterSumAgg myagg = (LetterSumAgg) agg;
    return myagg.sum;
}
merge()
一部分是agg,一部分是partial,partial就用到了integerOI,是部分聚合的字符个数的结果。merge会把两部分的内容进行汇总加和
/**
 * Folds one partial sum into the buffer. A null partial is ignored.
 *
 * @param agg     the buffer to update; must be a LetterSumAgg
 * @param partial a partial result, read through integerOI
 */
public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
    if (partial == null) {
        return;
    }
    final LetterSumAgg target = (LetterSumAgg) agg;
    final Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
    // Add the partial straight into the target buffer.
    target.add(partialSum);
}
terminate()
输出最后的结果
/**
 * Returns the final character count held in the given buffer.
 *
 * @param agg the buffer to read; must be a LetterSumAgg
 * @return the buffer's sum as an Integer
 */
public Object terminate(AggregationBuffer agg) throws HiveException {
    // The result comes from the buffer alone; nothing is written back to the
    // evaluator, so evaluating one buffer cannot affect another.
    LetterSumAgg myagg = (LetterSumAgg) agg;
    return myagg.sum;
}
源码2:实现collect_set
https://blog.csdn.net/xxydzyr/article/details/100975350
这段代码和实现字符串中字符个数累加的不同主要是其存储的是list,所以中间使用的类型不太一样,所以中间存储内容的类型需要根据情况去做调整
用到的类:
StandardListObjectInspector
处理存储为Java列表或Java数组对象的列表数据。始终使用ObjectInspectorFactory创建新的ObjectInspector对象,而不是直接创建此类的实例。
/**
 * Records the inspectors for this mode's input and returns a list inspector
 * describing the output.
 *
 * @param m          the mode this evaluator instance runs in
 * @param parameters inspectors for the input; only the first is read
 * @return a standard list inspector in every branch
 * @throws HiveException if the superclass initialisation fails
 */
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
    super.init(m, parameters);

    if (m == Mode.PARTIAL1) {
        // Original values in, a list of the standardised element type out.
        inputOI = (PrimitiveObjectInspector) parameters[0];
        return ObjectInspectorFactory.getStandardListObjectInspector(
                (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI));
    }

    if (!(parameters[0] instanceof StandardListObjectInspector)) {
        // The input is not a list, so treat it as individual elements.
        // NOTE(review): presumably the case with no map-side partial aggregation - confirm.
        inputOI = (PrimitiveObjectInspector) ObjectInspectorUtils
                .getStandardObjectInspector(parameters[0]);
        // Return the list inspector as-is. It is not a PrimitiveObjectInspector,
        // so casting it to one throws ClassCastException at runtime.
        return ObjectInspectorFactory.getStandardListObjectInspector(inputOI);
    }

    // The input is a partial list: remember its inspector and element inspector,
    // and emit the standardised list type.
    internalMergeOI = (StandardListObjectInspector) parameters[0];
    inputOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
    loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
    return loi;
}
其他的地方都比较类似,inputOI处理在了putInfo里
注意terminate和terminatePartial输出的类型
网友评论