有以下Hive表的定义:
create table topic_recommend_score (
category_id int,
topic_id bigint,
score double,
rank int
);
这张表是我们业务里话题推荐分值表的简化版本。category_id代表分类ID,topic_id是话题ID,score是评分值。rank代表每个分类下话题分值的排名,用开窗函数计算出来的:
row_number() over(partition by t.category_id order by t.score desc)
在对外提供推荐结果时,我们会将每个小组下排名前1000的话题ID取出,拼成一个逗号分隔的字符串,处理之后送入HBase供调用方查询。拼合的SQL语句如下:
select category_id,
concat_ws(',',collect_list(cast(topic_id as string)))
from topic_recommend_score
where rank >= 1 and rank <= 1000
group by category_id;
看起来没什么问题?但实际上是错误的。输出结果中总会有一些category_id对应的列表顺序异常,比如本来排名正数与排名倒数的两批ID调换了位置,即rank变成了n-3, n-2, n-1, n, 5, 6, 7, ..., n-4, 1, 2, 3, 4
。
产生这个问题的根本原因自然在MapReduce,如果启动了多于一个mapper/reducer来处理数据,select出来的数据顺序就几乎肯定与原始顺序不同了。考虑把mapper数固定成1比较麻烦(见我之前写的那篇Hive调优文章),也不现实,所以要迂回地解决问题:把rank加进来再进行一次排序,拼接完之后把rank去掉。如下:
select category_id,
regexp_replace(
concat_ws(',',
sort_array(
collect_list(
concat_ws(':',lpad(cast(rank as string),5,'0'),cast(topic_id as string))
)
)
),
'\\d+\:','')
from topic_recommend_score
where rank >= 1 and rank <= 1000
group by category_id;
这里将rank放在了topic_id之前,用冒号分隔,然后用sort_array函数对collect_list之后的结果进行排序(只支持升序)。特别注意,rank必须要在高位补足够的0对齐,因为排序的是字符串而不是数字,如果不补0的话,按字典序排序就会变成1, 10, 11, 12, 13, 2, 3, 4...
,又不对了。
将排序的结果拼起来之后,用regexp_replace函数替换掉冒号及其前面的数字,大功告成。
顺便看一下Hive源码中collect_list和collect_set函数对应的逻辑吧。
public class GenericUDAFMkCollectionEvaluator extends GenericUDAFEvaluator
implements Serializable {
private static final long serialVersionUID = 1l;
enum BufferType { SET, LIST }
// For PARTIAL1 and COMPLETE: ObjectInspectors for original data
private transient PrimitiveObjectInspector inputOI;
// For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list
// of objs)
private transient StandardListObjectInspector loi;
private transient ListObjectInspector internalMergeOI;
private BufferType bufferType;
//needed by kyro
public GenericUDAFMkCollectionEvaluator() {
}
public GenericUDAFMkCollectionEvaluator(BufferType bufferType){
this.bufferType = bufferType;
}
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
super.init(m, parameters);
// init output object inspectors
// The output of a partial aggregation is a list
if (m == Mode.PARTIAL1) {
inputOI = (PrimitiveObjectInspector) parameters[0];
return ObjectInspectorFactory
.getStandardListObjectInspector((PrimitiveObjectInspector) ObjectInspectorUtils
.getStandardObjectInspector(inputOI));
} else {
if (!(parameters[0] instanceof ListObjectInspector)) {
//no map aggregation.
inputOI = (PrimitiveObjectInspector) ObjectInspectorUtils
.getStandardObjectInspector(parameters[0]);
return (StandardListObjectInspector) ObjectInspectorFactory
.getStandardListObjectInspector(inputOI);
} else {
internalMergeOI = (ListObjectInspector) parameters[0];
inputOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
return loi;
}
}
}
class MkArrayAggregationBuffer extends AbstractAggregationBuffer {
private Collection<Object> container;
public MkArrayAggregationBuffer() {
if (bufferType == BufferType.LIST){
container = new ArrayList<Object>();
} else if(bufferType == BufferType.SET){
container = new LinkedHashSet<Object>();
} else {
throw new RuntimeException("Buffer type unknown");
}
}
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
((MkArrayAggregationBuffer) agg).container.clear();
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
MkArrayAggregationBuffer ret = new MkArrayAggregationBuffer();
return ret;
}
//mapside
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException {
assert (parameters.length == 1);
Object p = parameters[0];
if (p != null) {
MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
putIntoCollection(p, myagg);
}
}
//mapside
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
List<Object> ret = new ArrayList<Object>(myagg.container.size());
ret.addAll(myagg.container);
return ret;
}
@Override
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
List<Object> partialResult = (ArrayList<Object>) internalMergeOI.getList(partial);
if (partialResult != null) {
for(Object i : partialResult) {
putIntoCollection(i, myagg);
}
}
}
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
List<Object> ret = new ArrayList<Object>(myagg.container.size());
ret.addAll(myagg.container);
return ret;
}
private void putIntoCollection(Object p, MkArrayAggregationBuffer myagg) {
Object pCopy = ObjectInspectorUtils.copyToStandardObject(p, this.inputOI);
myagg.container.add(pCopy);
}
public BufferType getBufferType() {
return bufferType;
}
public void setBufferType(BufferType bufferType) {
this.bufferType = bufferType;
}
}
网友评论