Understanding How a UDAF Is Written


Author: kaiker | Published 2020-09-17 15:25

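A UDAF is an aggregate function that collapses many rows into a single row. How it runs is closely tied to the MapReduce (MR) process.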



Source 0: a sum implementation

https://www.cnblogs.com/yin-fei/p/10879731.html is a fairly simple example.


Source 1: a UDAF that counts the characters in strings

https://github.com/rathboma/hive-extension-examples/blob/master/src/main/java/com/matthewrathbone/example/TotalNumOfLettersGenericUDAF.java

import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.ql.exec.Description;

@Description(name = "letters", value = "_FUNC_(expr) - Returns total number of letters in all the strings of a column.")
public class TotalNumOfLettersGenericUDAF extends AbstractGenericUDAFResolver {

    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly one argument is expected.");
        }
        
        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        
        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE){
            throw new UDFArgumentTypeException(0,
                            "Argument must be PRIMITIVE, but "
                            + oi.getCategory().name()
                            + " was passed.");
        }
        
        PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
        
        if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
            throw new UDFArgumentTypeException(0,
                            "Argument must be String, but "
                            + inputOI.getPrimitiveCategory().name()
                            + " was passed.");
        }
        
        return new TotalNumOfLettersEvaluator();
    }

    public static class TotalNumOfLettersEvaluator extends GenericUDAFEvaluator {

        PrimitiveObjectInspector inputOI;   // raw input rows (PARTIAL1, COMPLETE)
        ObjectInspector outputOI;           // partial and final result: one Integer
        PrimitiveObjectInspector integerOI; // partial sums (PARTIAL2, FINAL)

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {

            assert (parameters.length == 1);
            super.init(m, parameters);

            // init input object inspectors
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                integerOI = (PrimitiveObjectInspector) parameters[0];
            }

            // init output object inspector:
            // both the partial and the final result are a single Java Integer
            outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                    ObjectInspectorOptions.JAVA);
            return outputOI;
        }

        /**
         * class for storing the current sum of letters
         */
        static class LetterSumAgg implements AggregationBuffer {
            int sum = 0;
            void add(int num){
                sum += num;
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            LetterSumAgg result = new LetterSumAgg();
            return result;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            // clear the buffer that was passed in so it can be reused for another group
            ((LetterSumAgg) agg).sum = 0;
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            if (parameters[0] != null) {
                LetterSumAgg myagg = (LetterSumAgg) agg;
                Object p1 = ((PrimitiveObjectInspector) inputOI).getPrimitiveJavaObject(parameters[0]);
                myagg.add(String.valueOf(p1).length());
            }
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            // return this group's partial sum only; nothing is accumulated on the evaluator
            LetterSumAgg myagg = (LetterSumAgg) agg;
            return myagg.sum;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                LetterSumAgg myagg = (LetterSumAgg) agg;
                // partial is the Integer produced by terminatePartial, read through integerOI
                Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
                myagg.add(partialSum);
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            return myagg.sum;
        }

    }
}

Next, let's take it apart piece by piece.

getEvaluator

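getEvaluator mainly checks whether the input arguments are valid and returns the matching evaluator. The evaluator is the object that actually processes the data and produces the UDAF's result.

parameters.length checks the number of input arguments.
oi is the ObjectInspector for the argument; it is used to check that the argument is a PRIMITIVE type.
inputOI then checks that the argument is a String.
Finally, it returns an evaluator: an instance of the nested class of the same name defined below.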

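Classes/interfaces used along the way:
TypeInfoUtils
The getStandardJavaObjectInspectorFromTypeInfo method turns the argument's TypeInfo into a standard Java ObjectInspector.
ObjectInspector
Its Category defines five kinds: primitive (Primitive), list (List), key-value map (Map), struct (Struct), and union (Union).
More on this interface: https://blog.csdn.net/weixin_39469127/article/details/89739285
PrimitiveObjectInspector
Lists the primitive types: byte, char, float, string, and so on.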

    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly one argument is expected.");
        }
        
        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        
        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE){
            throw new UDFArgumentTypeException(0,
                            "Argument must be PRIMITIVE, but "
                            + oi.getCategory().name()
                            + " was passed.");
        }
        
        PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
        
        if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
            throw new UDFArgumentTypeException(0,
                            "Argument must be String, but "
                            + inputOI.getPrimitiveCategory().name()
                            + " was passed.");
        }
        
        return new TotalNumOfLettersEvaluator();
    }
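The three checks in getEvaluator run in a fixed order: argument count, then category, then primitive kind. The plain-Java sketch below mirrors that order without any Hive dependency. LettersResolverSketch, ArgType, and the two enums are hypothetical stand-ins for TypeInfo, ObjectInspector.Category and PrimitiveObjectInspector.PrimitiveCategory, and the returned string only stands in for the evaluator instance.

```java
import java.util.List;

// Hypothetical stand-ins for Hive's TypeInfo and category enums; not Hive types.
class LettersResolverSketch {
    enum Category { PRIMITIVE, LIST, MAP, STRUCT, UNION }

    enum PrimitiveKind { BYTE, CHAR, INT, FLOAT, STRING }

    static final class ArgType {
        final Category category;
        final PrimitiveKind primitive; // only meaningful when category == PRIMITIVE

        ArgType(Category category, PrimitiveKind primitive) {
            this.category = category;
            this.primitive = primitive;
        }
    }

    // Same order as getEvaluator: argument count, then category, then primitive kind.
    static String resolve(List<ArgType> args) {
        if (args.size() != 1) {
            throw new IllegalArgumentException("Exactly one argument is expected.");
        }
        ArgType t = args.get(0);
        if (t.category != Category.PRIMITIVE) {
            throw new IllegalArgumentException(
                    "Argument must be PRIMITIVE, but " + t.category.name() + " was passed.");
        }
        if (t.primitive != PrimitiveKind.STRING) {
            throw new IllegalArgumentException(
                    "Argument must be String, but " + t.primitive + " was passed.");
        }
        // Stands in for "return new TotalNumOfLettersEvaluator();"
        return "TotalNumOfLettersEvaluator";
    }
}
```

In the real resolver the errors are UDFArgumentTypeException, whose first constructor argument is the index of the offending argument.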

TotalNumOfLettersEvaluator

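It first declares three fields; their types are covered in the classes/interfaces list of the previous part.

Different UDAFs may declare different fields, depending on what they need. Which phase each field is used in can be read off init().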

PrimitiveObjectInspector inputOI;
ObjectInspector outputOI;
PrimitiveObjectInspector integerOI;

init()

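init() defines the input and output of each MR phase.
For background on the phases, see https://blog.csdn.net/weixin_37766087/article/details/100940409

PARTIAL1: the map phase of MapReduce, from raw data to a partial aggregation (iterate() and terminatePartial() are called).
PARTIAL2: the map-side combiner phase, which merges map output on the map side, from partial aggregation to partial aggregation (merge() and terminatePartial() are called).
FINAL: the reduce phase, from partial aggregations to the complete aggregation (merge() and terminate() are called).
COMPLETE: from raw data straight to the complete aggregation (iterate() and terminate() are called). This indicates a job with a map but no reduce, so the map side produces the result directly; it can also occur on the reduce side when there is no map-side aggregation.

In this example, inputOI is used in the modes that read raw rows (PARTIAL1 and COMPLETE), and integerOI in the combiner and reduce modes (PARTIAL2 and FINAL).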

Class used:
ObjectInspectorFactory
The main way to create new ObjectInspector instances.

    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {

        assert (parameters.length == 1);
        super.init(m, parameters);

        // init input object inspectors
        if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
            inputOI = (PrimitiveObjectInspector) parameters[0];
        } else {
            integerOI = (PrimitiveObjectInspector) parameters[0];
        }

        // init output object inspector:
        // both the partial and the final result are a single Java Integer
        outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                ObjectInspectorOptions.JAVA);
        return outputOI;
    }
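Each mode corresponds to a fixed pair of evaluator calls, as listed in the GenericUDAFEvaluator.Mode documentation. The plain-Java enum below only summarizes that mapping (UdafMode is a made-up name, not Hive's Mode enum) and adds a helper showing why init() assigns inputOI in exactly the modes that call iterate().

```java
import java.util.List;

// Illustrative summary of the four evaluator modes; not Hive's Mode enum.
enum UdafMode {
    PARTIAL1(List.of("iterate", "terminatePartial")), // raw rows -> partial aggregate
    PARTIAL2(List.of("merge", "terminatePartial")),   // partial  -> partial aggregate
    FINAL(List.of("merge", "terminate")),             // partial  -> final result
    COMPLETE(List.of("iterate", "terminate"));        // raw rows -> final result

    final List<String> calls;

    UdafMode(List<String> calls) {
        this.calls = calls;
    }

    // Modes that call iterate() receive raw rows, so the letters example stores
    // parameters[0] in inputOI for them and in integerOI for the others.
    boolean readsRawRows() {
        return calls.contains("iterate");
    }
}
```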

LetterSumAgg

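This class holds the intermediate aggregation result and implements the AggregationBuffer interface. Every UDAF has such a class, although its name and contents differ from one UDAF to another.
Here, sum holds the number of characters counted so far, and the add method does the accumulation.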

    static class LetterSumAgg implements AggregationBuffer {
        int sum = 0;
        void add(int num) {
            sum += num;
        }
    }

getNewAggregationBuffer()

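Returns a new aggregation buffer, which holds the intermediate state of one group. In hash-based GROUP BY a new buffer is created for each new group key; when rows arrive already grouped, a buffer can instead be reused via reset().
https://www.cnblogs.com/yin-fei/p/10879731.html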

    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        LetterSumAgg result = new LetterSumAgg();
        return result;
    }
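To make "one buffer per group" concrete, here is a plain-Java sketch of hash-style GROUP BY. It is not Hive code: PerGroupBuffers, groupBy and the two-column String rows (group key, value) are hypothetical, and getNewAggregationBuffer here is a local stand-in that simply returns a fresh LetterSumAgg.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical model of hash aggregation; not Hive's GroupByOperator.
class PerGroupBuffers {
    // Same shape as LetterSumAgg in the article.
    static class LetterSumAgg {
        int sum = 0;
        void add(int num) { sum += num; }
    }

    // Local stand-in for the evaluator's getNewAggregationBuffer().
    static LetterSumAgg getNewAggregationBuffer() {
        return new LetterSumAgg();
    }

    // rows[i][0] is the group key, rows[i][1] the string value (may be null).
    static Map<String, Integer> groupBy(String[][] rows) {
        Map<String, LetterSumAgg> buffers = new LinkedHashMap<>();
        for (String[] row : rows) {
            // First time a key is seen: allocate a buffer; afterwards reuse it.
            LetterSumAgg agg = buffers.computeIfAbsent(row[0], k -> getNewAggregationBuffer());
            if (row[1] != null) {
                agg.add(row[1].length());
            }
        }
        Map<String, Integer> result = new LinkedHashMap<>();
        buffers.forEach((key, agg) -> result.put(key, agg.sum));
        return result;
    }
}
```

With rows (a, "hi"), (b, "hey"), (a, "yo"), the sketch allocates two buffers and returns a=4, b=3.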

reset()

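Resets the aggregation state so that a buffer can be reused when a GROUP BY moves on to a different group. A reset() that only creates a new local LetterSumAgg, as in the original TotalNumOfLettersGenericUDAF source, does not clear anything; it has to set the passed-in buffer's sum back to 0.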
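The difference between the original reset() body and a working one can be shown without Hive. In this sketch, resetAsWritten mirrors the original body and resetFixed clears the buffer it is given; in the real evaluator the fixed line would be ((LetterSumAgg) agg).sum = 0;. ResetSketch and both method names are hypothetical.

```java
// Hypothetical plain-Java comparison; LetterSumAgg mirrors the article's buffer.
class ResetSketch {
    static class LetterSumAgg {
        int sum = 0;
        void add(int num) { sum += num; }
    }

    // Original body: builds a new local object and discards it,
    // so the caller's buffer keeps its old sum.
    static void resetAsWritten(LetterSumAgg agg) {
        LetterSumAgg myagg = new LetterSumAgg();
    }

    // Corrected: clear the state of the buffer that was passed in.
    static void resetFixed(LetterSumAgg agg) {
        agg.sum = 0;
    }
}
```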

iterate()

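Processes each row: it casts the buffer to LetterSumAgg, reads the input value through inputOI, takes its length, and adds that to the sum. Note that length() counts all characters, not only letters.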

    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        if (parameters[0] != null) {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            Object p1 = ((PrimitiveObjectInspector) inputOI).getPrimitiveJavaObject(parameters[0]);
            myagg.add(String.valueOf(p1).length());
        }
    }
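A plain-Java sketch of what iterate() adds up, with a made-up IterateSketch class standing in for the evaluator and an Object[] standing in for the column. It shows two details: length() counts every character (spaces included), and the null check matters because String.valueOf on a null Object returns the 4-character string "null".

```java
// Hypothetical stand-in for iterate() over one column; no Hive types involved.
class IterateSketch {
    static int countCharacters(Object[] column) {
        int sum = 0;
        for (Object value : column) {
            // Skip nulls, as iterate() does; otherwise "null" would add 4.
            if (value != null) {
                sum += String.valueOf(value).length();
            }
        }
        return sum;
    }
}
```

For the column {"ab", null, "c d"} the result is 5.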

terminatePartial()

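Outputs the map/combiner (partial) result: the sum held in this group's intermediate buffer. The original TotalNumOfLettersGenericUDAF source instead adds myagg.sum to an evaluator-level total field and returns that. One evaluator serves many groups, so that total carries counts over from earlier groups; the partial result should be just myagg.sum.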

    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
        // return this group's partial sum only; nothing is accumulated on the evaluator
        LetterSumAgg myagg = (LetterSumAgg) agg;
        return myagg.sum;
    }
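Why an evaluator-level total is risky: the sketch below (hypothetical class and method names, no Hive types) runs one shared evaluator over two groups whose buffers hold 3 and 4. The version that accumulates into total emits 3 and then 7, while returning the buffer's own sum emits 3 and 4.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model: one evaluator instance shared by several groups.
class TerminatePartialSketch {
    static class LetterSumAgg { int sum = 0; }

    static class Evaluator {
        int total = 0;

        Object terminatePartialAsWritten(LetterSumAgg agg) {
            total += agg.sum; // carries state over from earlier groups
            return total;
        }

        Object terminatePartialFixed(LetterSumAgg agg) {
            return agg.sum;   // depends only on this group's buffer
        }
    }

    static List<Object> run(boolean fixed, int... groupSums) {
        Evaluator ev = new Evaluator();
        List<Object> out = new ArrayList<>();
        for (int s : groupSums) {
            LetterSumAgg agg = new LetterSumAgg();
            agg.sum = s;
            out.add(fixed ? ev.terminatePartialFixed(agg) : ev.terminatePartialAsWritten(agg));
        }
        return out;
    }
}
```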

merge()

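merge() takes two inputs: agg, the current buffer, and partial, a partial result (the character count from a partial aggregation), which is read through integerOI. merge adds the two together.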

    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if (partial != null) {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            // partial is the Integer produced by terminatePartial, read through integerOI
            Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
            myagg.add(partialSum);
        }
    }
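merge() is what lets the partial path and the one-step path agree. This sketch, again plain Java with hypothetical names, runs PARTIAL1 (iterate + terminatePartial) on each split and FINAL (merge + terminate) over the partial counts, and compares the result with COMPLETE (iterate + terminate) over the same raw rows.

```java
import java.util.List;

// Hypothetical simulation of the letters UDAF across modes; no Hive types.
class MergeSketch {
    static class LetterSumAgg {
        int sum = 0;
        void add(int num) { sum += num; }
    }

    static void iterate(LetterSumAgg agg, String value) {
        if (value != null) {
            agg.add(value.length());
        }
    }

    static Integer terminatePartial(LetterSumAgg agg) { return agg.sum; }

    // Fold one partial count into the buffer, skipping nulls.
    static void merge(LetterSumAgg agg, Integer partial) {
        if (partial != null) {
            agg.add(partial);
        }
    }

    static Integer terminate(LetterSumAgg agg) { return agg.sum; }

    // PARTIAL1 on each split, then FINAL over all partial results.
    static int partialThenFinal(List<List<String>> splits) {
        LetterSumAgg reducer = new LetterSumAgg();
        for (List<String> split : splits) {
            LetterSumAgg mapper = new LetterSumAgg();
            for (String v : split) {
                iterate(mapper, v);
            }
            merge(reducer, terminatePartial(mapper));
        }
        return terminate(reducer);
    }

    // COMPLETE: iterate over every raw row, then terminate.
    static int complete(List<List<String>> splits) {
        LetterSumAgg agg = new LetterSumAgg();
        for (List<String> split : splits) {
            for (String v : split) {
                iterate(agg, v);
            }
        }
        return terminate(agg);
    }
}
```

For splits ["hive", null] and ["udaf", "x"] both paths return 9.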

terminate()

Outputs the final result.

    public Object terminate(AggregationBuffer agg) throws HiveException {
        LetterSumAgg myagg = (LetterSumAgg) agg;
        return myagg.sum;
    }

Source 2: implementing collect_set

https://blog.csdn.net/xxydzyr/article/details/100975350

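The main difference from the character-count UDAF is that this one stores a list, so the intermediate types are different: the type used to hold the intermediate content has to be chosen to fit the aggregation.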

Class used:
StandardListObjectInspector
Handles list data stored as a Java List or Java array. Always create new ObjectInspector objects through ObjectInspectorFactory rather than instantiating this class directly.

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
    import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;

    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
        super.init(m, parameters);

        if (m == Mode.PARTIAL1) {
            // map side: raw values in, partial result is a list of those values
            inputOI = (PrimitiveObjectInspector) parameters[0];
            return ObjectInspectorFactory.getStandardListObjectInspector(
                    (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI)
            );
        } else {
            if (!(parameters[0] instanceof StandardListObjectInspector)) {
                // no map-side aggregation (COMPLETE): raw values in, final list out;
                // the return value is a list inspector, so no PrimitiveObjectInspector cast
                inputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.
                        getStandardObjectInspector(parameters[0]);
                return ObjectInspectorFactory.getStandardListObjectInspector(inputOI);
            } else {
                // PARTIAL2 / FINAL: the input is the partial list from terminatePartial
                internalMergeOI = (StandardListObjectInspector) parameters[0];
                inputOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
                loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
                return loi;
            }
        }
    }

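Everything else is similar to the previous example; inputOI is used inside putInfo.
Pay attention to the types returned by terminate and terminatePartial: here they are lists, not a single number.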
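For collect_set the buffer holds a set, and the partial result travels as a list, which is why init() returns list ObjectInspectors. The plain-Java sketch below assumes that shape; CollectSetSketch, SetAgg and put are hypothetical (put only stands in for the per-value helper the article calls putInfo, whose actual body is in the linked post). A LinkedHashSet keeps the output order deterministic for the example; Hive's collect_set does not promise any order.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical model of a collect_set-style UDAF; no Hive types.
class CollectSetSketch {
    // The buffer holds a set of distinct values instead of an int.
    static class SetAgg {
        final Set<Object> values = new LinkedHashSet<>();
    }

    // Stand-in for the per-value helper: add one value, ignoring null.
    static void put(SetAgg agg, Object value) {
        if (value != null) {
            agg.values.add(value);
        }
    }

    static void iterate(SetAgg agg, Object value) {
        put(agg, value);
    }

    // The partial result is a list, matching the list inspector from init().
    static List<Object> terminatePartial(SetAgg agg) {
        return new ArrayList<>(agg.values);
    }

    // Unpack a partial list and put each element back into the set.
    static void merge(SetAgg agg, List<Object> partial) {
        if (partial != null) {
            for (Object v : partial) {
                put(agg, v);
            }
        }
    }

    static List<Object> terminate(SetAgg agg) {
        return new ArrayList<>(agg.values);
    }
}
```

Two map-side buffers holding {a, b} and {b, c} merge into [a, b, c].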


      Article link: https://www.haomeiwen.com/subject/lnqqyktx.html