序
本文主要研究一下flink Table的groupBy操作
Table.groupBy
flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala
/**
  * Excerpt of the Table API class: a table is described by the environment it
  * belongs to and by the logical plan node that produces it.
  */
class Table(
    private[flink] val tableEnv: TableEnvironment,
    private[flink] val logicalPlan: LogicalNode) {

  //......

  /**
    * Groups this table by the expressions contained in `fields`.
    *
    * The string is turned into expressions by
    * `ExpressionParser.parseExpressionList` and then handed to the
    * `Expression*` overload.
    */
  def groupBy(fields: String): GroupedTable =
    groupBy(ExpressionParser.parseExpressionList(fields): _*)

  /**
    * Groups this table by the given expressions.
    *
    * @param fields the grouping key expressions
    * @return a [[GroupedTable]] wrapping this table and the grouping key
    */
  def groupBy(fields: Expression*): GroupedTable = new GroupedTable(this, fields)

  //......
}
- Table的groupBy操作支持两种参数,一种是String类型,一种是Expression类型;String参数的方法是将String转换为Expression,最后调用的是Expression参数的groupBy方法,该方法创建了GroupedTable
GroupedTable
flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala
/**
  * A table together with the grouping key it has been grouped on.
  *
  * @param table    the table that was grouped
  * @param groupKey the grouping key expressions
  */
class GroupedTable(
    private[flink] val table: Table,
    private[flink] val groupKey: Seq[Expression]) {

  /**
    * Builds the result table of a grouped select.
    *
    * The resulting plan is, from the inside out: a projection of the referenced
    * fields on top of the original plan, an aggregate over the grouping key,
    * and a final projection of the requested fields. Each node is validated
    * as soon as it is created.
    *
    * @param fields the expressions to select
    * @throws ValidationException if any window property is referenced
    */
  def select(fields: Expression*): Table = {
    val env = table.tableEnv
    val expanded = expandProjectList(fields, table.logicalPlan, env)
    val (aggNames, propNames) = extractAggregationsAndProperties(expanded, env)

    // Window properties are rejected here: this table is grouped, not windowed.
    if (propNames.nonEmpty) {
      throw new ValidationException("Window properties can only be used on windowed tables.")
    }

    val projectsOnAgg = replaceAggregationsAndProperties(expanded, env, aggNames, propNames)
    val projectFields = extractFieldReferences(expanded ++ groupKey)

    // Build the plan bottom-up, validating each level before wrapping it.
    val aliasedAggregates = aggNames.map(entry => Alias(entry._1, entry._2)).toSeq
    val inputProjection = Project(projectFields, table.logicalPlan).validate(env)
    val aggregation = Aggregate(groupKey, aliasedAggregates, inputProjection).validate(env)
    val resultProjection = Project(projectsOnAgg, aggregation).validate(env)

    new Table(env, resultProjection)
  }

  /**
    * String variant of [[select]]: parses `fields` into expressions, resolves
    * aggregate function calls, and delegates to the `Expression*` overload.
    */
  def select(fields: String): Table = {
    // get the correct expression for AggFunctionCall
    val resolved = ExpressionParser
      .parseExpressionList(fields)
      .map(replaceAggFunctionCall(_, table.tableEnv))
    select(resolved: _*)
  }
}
- GroupedTable有两个属性,一个是原始的Table,一个是Seq[Expression]类型的groupKey
- GroupedTable提供两个select方法,参数类型分别为String、Expression,String类型的参数最后也是转为Expression类型
- select方法使用Project创建新的Table,而Project则是通过Aggregate来创建
Aggregate
flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/plan/logical/operators.scala
/**
 * Logical plan node for a grouped aggregation over a single child node.
 *
 * @param groupingExpressions  the expressions to group on
 * @param aggregateExpressions the named aggregate expressions to compute
 * @param child                the input node
 */
case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: LogicalNode) extends UnaryNode {
/**
 * Output attributes: grouping expressions first, then aggregate expressions.
 * An expression that is not a NamedExpression is aliased with its toString.
 */
override def output: Seq[Attribute] = {
(groupingExpressions ++ aggregateExpressions) map {
case ne: NamedExpression => ne.toAttribute
case e => Alias(e, e.toString).toAttribute
}
}
/**
 * Constructs the child first, then calls relBuilder.aggregate with a group key
 * built from the grouping expressions and one agg call per aggregate expression.
 *
 * @throws RuntimeException if an aggregate expression is not an Alias wrapping an Aggregation
 */
override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
child.construct(relBuilder)
relBuilder.aggregate(
relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
aggregateExpressions.map {
case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
case _ => throw new RuntimeException("This should never happen.")
}.asJava)
}
/**
 * Runs the parent validation, then checks every aggregate expression and every
 * grouping expression; violations are reported through failValidation.
 *
 * @return the Aggregate produced by super.validate
 */
override def validate(tableEnv: TableEnvironment): LogicalNode = {
implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
val groupingExprs = resolvedAggregate.groupingExpressions
val aggregateExprs = resolvedAggregate.aggregateExpressions
// The two local validators used here are defined further down in this method.
aggregateExprs.foreach(validateAggregateExpression)
groupingExprs.foreach(validateGroupingExpression)
// Case order matters below: the DistinctAgg and requiresOver cases are tried
// before the general Aggregation case.
def validateAggregateExpression(expr: Expression): Unit = expr match {
// a distinct wrapper must directly wrap a (non-distinct) aggregation
case distinctExpr: DistinctAgg =>
distinctExpr.child match {
case _: DistinctAgg => failValidation(
"Chained distinct operators are not supported!")
case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
case _ => failValidation(
"Distinct operator can only be applied to aggregation expressions!")
}
// check aggregate function
case aggExpr: Aggregation
if aggExpr.getSqlAggFunction.requiresOver =>
failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
// check no nested aggregation exists.
case aggExpr: Aggregation =>
aggExpr.children.foreach { child =>
child.preOrderVisit {
case agg: Aggregation =>
failValidation(
"It's not allowed to use an aggregate function as " +
"input of another aggregate function")
case _ => // OK
}
}
// an attribute outside of an aggregation must be one of the grouping expressions
case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
failValidation(
s"expression '$a' is invalid because it is neither" +
" present in group by nor an aggregate function")
case e if groupingExprs.exists(_.checkEquals(e)) => // OK
// any other expression: validate its children recursively
case e => e.children.foreach(validateAggregateExpression)
}
// A grouping expression is rejected unless its result type reports isKeyType.
def validateGroupingExpression(expr: Expression): Unit = {
if (!expr.resultType.isKeyType) {
failValidation(
s"expression $expr cannot be used as a grouping expression " +
"because it's not a valid key type which must be hashable and comparable")
}
}
resolvedAggregate
}
}
- Aggregate继承了UnaryNode,它接收三个参数,一个是Seq[Expression]类型的groupingExpressions,一个是Seq[NamedExpression]类型的aggregateExpressions,一个是LogicalNode类型的child;construct方法调用了relBuilder.aggregate,传入的RelBuilder.GroupKey参数是通过relBuilder.groupKey构建,而传入的RelBuilder.AggCall参数则是通过aggregateExpressions.map构造而来
RelBuilder.groupKey
calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java
/**
 * Excerpt of RelBuilder showing the {@code groupKey} factory methods.
 * Members not relevant to group keys are elided ({@code //......}).
 */
public class RelBuilder {
protected final RelOptCluster cluster;
protected final RelOptSchema relOptSchema;
private final RelFactories.FilterFactory filterFactory;
private final RelFactories.ProjectFactory projectFactory;
private final RelFactories.AggregateFactory aggregateFactory;
private final RelFactories.SortFactory sortFactory;
private final RelFactories.ExchangeFactory exchangeFactory;
private final RelFactories.SortExchangeFactory sortExchangeFactory;
private final RelFactories.SetOpFactory setOpFactory;
private final RelFactories.JoinFactory joinFactory;
private final RelFactories.SemiJoinFactory semiJoinFactory;
private final RelFactories.CorrelateFactory correlateFactory;
private final RelFactories.ValuesFactory valuesFactory;
private final RelFactories.TableScanFactory scanFactory;
private final RelFactories.MatchFactory matchFactory;
private final Deque<Frame> stack = new ArrayDeque<>();
private final boolean simplify;
private final RexSimplify simplifier;
//......
/** Creates an empty group key. */
public GroupKey groupKey() {
return groupKey(ImmutableList.of());
}
/** Creates a group key. */
public GroupKey groupKey(RexNode... nodes) {
return groupKey(ImmutableList.copyOf(nodes));
}
/** Creates a group key. */
public GroupKey groupKey(Iterable<? extends RexNode> nodes) {
// Builds the GroupKeyImpl directly: indicator false, no grouping sets, no alias.
return new GroupKeyImpl(ImmutableList.copyOf(nodes), false, null, null);
}
/** Creates a group key with grouping sets. */
public GroupKey groupKey(Iterable<? extends RexNode> nodes,
Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
return groupKey_(nodes, false, nodeLists);
}
/** Creates a group key of fields identified by ordinal. */
public GroupKey groupKey(int... fieldOrdinals) {
return groupKey(fields(ImmutableIntList.of(fieldOrdinals)));
}
/** Creates a group key of fields identified by name. */
public GroupKey groupKey(String... fieldNames) {
return groupKey(fields(ImmutableList.copyOf(fieldNames)));
}
/** Creates a group key from a bit set, using that same bit set as the only grouping set. */
public GroupKey groupKey(@Nonnull ImmutableBitSet groupSet) {
return groupKey(groupSet, ImmutableList.of(groupSet));
}
/** Creates a group key from a bit set plus grouping sets given as bit sets. */
public GroupKey groupKey(ImmutableBitSet groupSet,
@Nonnull Iterable<? extends ImmutableBitSet> groupSets) {
return groupKey_(groupSet, false, ImmutableList.copyOf(groupSets));
}
/**
 * Converts the bit sets into lists of nodes via {@code fields(...)} and
 * delegates to the node-based {@code groupKey_}.
 *
 * @throws IllegalArgumentException if {@code groupSet.length()} exceeds the
 *     field count of the row type of {@code peek()}
 */
private GroupKey groupKey_(ImmutableBitSet groupSet, boolean indicator,
@Nonnull ImmutableList<ImmutableBitSet> groupSets) {
if (groupSet.length() > peek().getRowType().getFieldCount()) {
throw new IllegalArgumentException("out of bounds: " + groupSet);
}
Objects.requireNonNull(groupSets);
final ImmutableList<RexNode> nodes =
fields(ImmutableIntList.of(groupSet.toArray()));
final List<ImmutableList<RexNode>> nodeLists =
Util.transform(groupSets,
bitSet -> fields(ImmutableIntList.of(bitSet.toArray())));
return groupKey_(nodes, indicator, nodeLists);
}
/**
 * Copies the nodes and every node list into immutable lists and creates the
 * GroupKeyImpl (with a null alias).
 */
private GroupKey groupKey_(Iterable<? extends RexNode> nodes,
boolean indicator,
Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
final ImmutableList.Builder<ImmutableList<RexNode>> builder =
ImmutableList.builder();
for (Iterable<? extends RexNode> nodeList : nodeLists) {
builder.add(ImmutableList.copyOf(nodeList));
}
return new GroupKeyImpl(ImmutableList.copyOf(nodes), indicator, builder.build(), null);
}
//......
}
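- RelBuilder提供了诸多groupKey方法用于创建GroupKey;其中groupKey(Iterable<? extends RexNode>)直接创建GroupKeyImpl(无参、RexNode数组、字段序号、字段名这几个重载最终都委托给它),而带grouping sets或ImmutableBitSet参数的重载最后调用的是私有方法groupKey_,该方法同样创建了GroupKeyImpl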
GroupKey
calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java
/**
 * A group key as produced by the {@code groupKey} factory methods of
 * RelBuilder and consumed by {@code RelBuilder.aggregate}.
 */
public interface GroupKey {
/** Assigns an alias to this group key.
*
* <p>Used to assign field names in the {@code group} operation.
*
* @param alias the alias to assign
* @return a group key carrying the given alias */
GroupKey alias(String alias);
}
/**
 * Default {@link GroupKey}: an immutable holder of the key nodes, the
 * indicator flag, the optional grouping sets and an optional alias.
 */
protected static class GroupKeyImpl implements GroupKey {
  final ImmutableList<RexNode> nodes;
  final boolean indicator;
  final ImmutableList<ImmutableList<RexNode>> nodeLists;
  final String alias;

  /**
   * @param nodes      the group key expressions; must not be null
   * @param indicator  asserted to be false
   * @param nodeLists  the grouping sets, or null
   * @param alias      the alias, or null
   */
  GroupKeyImpl(ImmutableList<RexNode> nodes, boolean indicator,
      ImmutableList<ImmutableList<RexNode>> nodeLists, String alias) {
    this.nodes = Objects.requireNonNull(nodes);
    assert !indicator;
    this.indicator = indicator;
    this.nodeLists = nodeLists;
    this.alias = alias;
  }

  /** Renders the nodes, followed by " as alias" when an alias is set. */
  @Override public String toString() {
    if (alias == null) {
      return nodes.toString();
    }
    return nodes + " as " + alias;
  }

  /**
   * Returns this instance when the alias is unchanged, otherwise a copy that
   * differs only in its alias.
   */
  public GroupKey alias(String alias) {
    if (Objects.equals(this.alias, alias)) {
      return this;
    }
    return new GroupKeyImpl(nodes, indicator, nodeLists, alias);
  }
}
- GroupKey接口定义了alias方法,用于给group操作的字段指定别名;GroupKeyImpl是GroupKey接口的实现类,其alias返回的是GroupKeyImpl
RelBuilder.aggregate
calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java
/**
 * Excerpt of RelBuilder showing the {@code aggregate} methods.
 * Members not relevant to aggregation are elided ({@code //......}).
 */
public class RelBuilder {
protected final RelOptCluster cluster;
protected final RelOptSchema relOptSchema;
private final RelFactories.FilterFactory filterFactory;
private final RelFactories.ProjectFactory projectFactory;
private final RelFactories.AggregateFactory aggregateFactory;
private final RelFactories.SortFactory sortFactory;
private final RelFactories.ExchangeFactory exchangeFactory;
private final RelFactories.SortExchangeFactory sortExchangeFactory;
private final RelFactories.SetOpFactory setOpFactory;
private final RelFactories.JoinFactory joinFactory;
private final RelFactories.SemiJoinFactory semiJoinFactory;
private final RelFactories.CorrelateFactory correlateFactory;
private final RelFactories.ValuesFactory valuesFactory;
private final RelFactories.TableScanFactory scanFactory;
private final RelFactories.MatchFactory matchFactory;
private final Deque<Frame> stack = new ArrayDeque<>();
private final boolean simplify;
private final RexSimplify simplifier;
//......
/** Creates an {@link Aggregate} with an array of
* calls. */
public RelBuilder aggregate(GroupKey groupKey, AggCall... aggCalls) {
return aggregate(groupKey, ImmutableList.copyOf(aggCalls));
}
/** Creates an {@link Aggregate} from already-built AggregateCalls, each wrapped
* in an AggCallImpl2. */
public RelBuilder aggregate(GroupKey groupKey,
List<AggregateCall> aggregateCalls) {
return aggregate(groupKey,
Lists.transform(aggregateCalls, AggCallImpl2::new));
}
/** Creates an {@link Aggregate} with a list of
* calls.
*
* <p>Pops the frame on top of the stack, creates the aggregate through
* {@code aggregateFactory} and pushes a new frame for it. When there are no
* aggregate calls, the method may instead return a plain projection or leave
* the builder unchanged (see the labeled block below). */
public RelBuilder aggregate(GroupKey groupKey, Iterable<AggCall> aggCalls) {
// Seed the registrar with the current fields and their names, then register
// the group key expressions; the resulting ordinals form the group set.
final Registrar registrar = new Registrar();
registrar.extraNodes.addAll(fields());
registrar.names.addAll(peek().getRowType().getFieldNames());
final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey;
final ImmutableBitSet groupSet =
ImmutableBitSet.of(registrar.registerExpressions(groupKey_.nodes));
// Shortcuts for an aggregate without aggregate calls; "break label" leaves
// this block and falls through to the general path.
label:
if (Iterables.isEmpty(aggCalls) && !groupKey_.indicator) {
final RelMetadataQuery mq = peek().getCluster().getMetadataQuery();
if (groupSet.isEmpty()) {
final Double minRowCount = mq.getMinRowCount(peek());
if (minRowCount == null || minRowCount < 1D) {
// We can't remove "GROUP BY ()" if there's a chance the rel could be
// empty.
break label;
}
}
// Only applies when registering the group key added no new expressions.
if (registrar.extraNodes.size() == fields().size()) {
final Boolean unique = mq.areColumnsUnique(peek(), groupSet);
if (unique != null && unique) {
// Rel is already unique.
return project(fields(groupSet.asList()));
}
}
final Double maxRowCount = mq.getMaxRowCount(peek());
if (maxRowCount != null && maxRowCount <= 1D) {
// If there is at most one row, rel is already unique.
return this;
}
}
// Resolve the grouping sets: each must be a subset of the group set and must
// not introduce expressions beyond those already registered.
final ImmutableList<ImmutableBitSet> groupSets;
if (groupKey_.nodeLists != null) {
final int sizeBefore = registrar.extraNodes.size();
final SortedSet<ImmutableBitSet> groupSetSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
for (ImmutableList<RexNode> nodeList : groupKey_.nodeLists) {
final ImmutableBitSet groupSet2 =
ImmutableBitSet.of(registrar.registerExpressions(nodeList));
if (!groupSet.contains(groupSet2)) {
throw new IllegalArgumentException("group set element " + nodeList
+ " must be a subset of group key");
}
groupSetSet.add(groupSet2);
}
groupSets = ImmutableList.copyOf(groupSetSet);
if (registrar.extraNodes.size() > sizeBefore) {
throw new IllegalArgumentException(
"group sets contained expressions not in group key: "
+ registrar.extraNodes.subList(sizeBefore,
registrar.extraNodes.size()));
}
} else {
groupSets = ImmutableList.of(groupSet);
}
// Register operands and filters of the aggregate calls before projecting, so
// that the projection below includes them.
for (AggCall aggCall : aggCalls) {
if (aggCall instanceof AggCallImpl) {
final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
registrar.registerExpressions(aggCall1.operands);
if (aggCall1.filter != null) {
registrar.registerExpression(aggCall1.filter);
}
}
}
// Project all registered expressions, then take the resulting frame off the
// stack to use its rel as the aggregate's input.
project(registrar.extraNodes);
rename(registrar.names);
final Frame frame = stack.pop();
final RelNode r = frame.rel;
// Convert every AggCall into an AggregateCall.
final List<AggregateCall> aggregateCalls = new ArrayList<>();
for (AggCall aggCall : aggCalls) {
final AggregateCall aggregateCall;
if (aggCall instanceof AggCallImpl) {
final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
// NOTE(review): presumably returns the ordinals assigned in the loop
// above for these already-registered expressions; confirm in Registrar.
final List<Integer> args =
registrar.registerExpressions(aggCall1.operands);
final int filterArg = aggCall1.filter == null ? -1
: registrar.registerExpression(aggCall1.filter);
if (aggCall1.distinct && !aggCall1.aggFunction.isQuantifierAllowed()) {
throw new IllegalArgumentException("DISTINCT not allowed");
}
if (aggCall1.filter != null && !aggCall1.aggFunction.allowsFilter()) {
throw new IllegalArgumentException("FILTER not allowed");
}
RelCollation collation =
RelCollations.of(aggCall1.orderKeys
.stream()
.map(orderKey ->
collation(orderKey, RelFieldCollation.Direction.ASCENDING,
null, Collections.emptyList()))
.collect(Collectors.toList()));
aggregateCall =
AggregateCall.create(aggCall1.aggFunction, aggCall1.distinct,
aggCall1.approximate, args, filterArg, collation,
groupSet.cardinality(), r, null, aggCall1.alias);
} else {
// Any other AggCall is treated as an AggCallImpl2 that already wraps an AggregateCall.
aggregateCall = ((AggCallImpl2) aggCall).aggregateCall;
}
aggregateCalls.add(aggregateCall);
}
assert ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets) : groupSets;
for (ImmutableBitSet set : groupSets) {
assert groupSet.contains(set);
}
RelNode aggregate = aggregateFactory.createAggregate(r,
groupKey_.indicator, groupSet, groupSets, aggregateCalls);
// build field list
final ImmutableList.Builder<Field> fields = ImmutableList.builder();
final List<RelDataTypeField> aggregateFields =
aggregate.getRowType().getFieldList();
int i = 0;
// first, group fields
for (Integer groupField : groupSet.asList()) {
RexNode node = registrar.extraNodes.get(groupField);
final SqlKind kind = node.getKind();
switch (kind) {
case INPUT_REF:
// a plain input reference reuses the field of the popped input frame
fields.add(frame.fields.get(((RexInputRef) node).getIndex()));
break;
default:
// any other expression gets a new field named after the aggregate's row type
String name = aggregateFields.get(i).getName();
RelDataTypeField fieldType =
new RelDataTypeFieldImpl(name, i, node.getType());
fields.add(new Field(ImmutableSet.of(), fieldType));
break;
}
i++;
}
// second, indicator fields (copy from aggregate rel type)
if (groupKey_.indicator) {
for (int j = 0; j < groupSet.cardinality(); ++j) {
final RelDataTypeField field = aggregateFields.get(i);
final RelDataTypeField fieldType =
new RelDataTypeFieldImpl(field.getName(), i, field.getType());
fields.add(new Field(ImmutableSet.of(), fieldType));
i++;
}
}
// third, aggregate fields. retain `i' as field index
for (int j = 0; j < aggregateCalls.size(); ++j) {
final AggregateCall call = aggregateCalls.get(j);
final RelDataTypeField fieldType =
new RelDataTypeFieldImpl(aggregateFields.get(i + j).getName(), i + j,
call.getType());
fields.add(new Field(ImmutableSet.of(), fieldType));
}
// Replace the popped input frame with a frame for the new aggregate.
stack.push(new Frame(aggregate, fields.build()));
return this;
}
//......
}
- RelBuilder的aggregate操作接收两个参数,一个是GroupKey,一个是集合类型的AggCall;该方法先通过stack.pop()取出栈顶的Frame,再将AggCall转换为AggregateCall,然后调用aggregateFactory.createAggregate方法创建新的RelNode,并据此构造新的Frame,最后通过stack.push重新放入栈顶
RelFactories.AggregateFactory.createAggregate
calcite-core-1.18.0-sources.jar!/org/apache/calcite/rel/core/RelFactories.java
/**
 * Excerpt of RelFactories showing the aggregate factory and its default
 * implementation. Other members are elided ({@code //......}).
 */
public class RelFactories {
//......
/** Default aggregate factory; an instance of {@link AggregateFactoryImpl}. */
public static final AggregateFactory DEFAULT_AGGREGATE_FACTORY =
new AggregateFactoryImpl();
/** Factory that creates an aggregate relational expression. */
public interface AggregateFactory {
/** Creates an aggregate. */
RelNode createAggregate(RelNode input, boolean indicator,
ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets,
List<AggregateCall> aggCalls);
}
/** Implementation of {@link AggregateFactory} that returns a LogicalAggregate. */
private static class AggregateFactoryImpl implements AggregateFactory {
// Passes every argument through to LogicalAggregate.create unchanged.
@SuppressWarnings("deprecation")
public RelNode createAggregate(RelNode input, boolean indicator,
ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets,
List<AggregateCall> aggCalls) {
return LogicalAggregate.create(input, indicator,
groupSet, groupSets, aggCalls);
}
}
//......
}
- RelFactories定义了AggregateFactory接口,该接口定义了createAggregate方法,用于将一系列的AggregateCall操作转为新的RelNode;AggregateFactoryImpl是AggregateFactory接口的实现类,它的createAggregate方法调用的是LogicalAggregate.create方法
LogicalAggregate.create
calcite-core-1.18.0-sources.jar!/org/apache/calcite/rel/logical/LogicalAggregate.java
/**
 * Excerpt of LogicalAggregate showing its static factory methods.
 * Other members are elided ({@code //......}).
 */
public final class LogicalAggregate extends Aggregate {
  //......

  /** Creates a LogicalAggregate with the indicator flag set to false. */
  public static LogicalAggregate create(final RelNode input,
      ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    return create_(input, false, groupSet, groupSets, aggCalls);
  }

  /** Creates a LogicalAggregate with an explicit indicator flag. */
  @Deprecated // to be removed before 2.0
  public static LogicalAggregate create(final RelNode input,
      boolean indicator,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    return create_(input, indicator, groupSet, groupSets, aggCalls);
  }

  /**
   * Shared implementation of both factories: builds the node in the input's
   * cluster with a trait set obtained for {@code Convention.NONE}.
   */
  private static LogicalAggregate create_(final RelNode input,
      boolean indicator,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    final RelOptCluster inputCluster = input.getCluster();
    return new LogicalAggregate(
        inputCluster,
        inputCluster.traitSetOf(Convention.NONE),
        input,
        indicator,
        groupSet,
        groupSets,
        aggCalls);
  }

  //......
}
- LogicalAggregate的create方法创建的是LogicalAggregate
小结
- Table的groupBy操作支持两种参数,一种是String类型,一种是Expression类型;String参数的方法是将String转换为Expression,最后调用的是Expression参数的groupBy方法,该方法创建了GroupedTable
- GroupedTable有两个属性,一个是原始的Table,一个是Seq[Expression]类型的groupKey;它提供两个select方法,参数类型分别为String、Expression,String类型的参数最后也是转为Expression类型;select方法使用Project创建新的Table,而Project则是通过Aggregate来创建
- Aggregate继承了UnaryNode,它接收三个参数,一个是Seq[Expression]类型的groupingExpressions,一个是Seq[NamedExpression]类型的aggregateExpressions,一个是LogicalNode类型的child;construct方法调用了relBuilder.aggregate,传入的RelBuilder.GroupKey参数是通过relBuilder.groupKey构建,而传入的RelBuilder.AggCall参数则是通过aggregateExpressions.map构造而来
- RelBuilder的aggregate操作接收两个参数,一个是GroupKey(GroupKey接口定义了alias方法,用于给group操作的字段指定别名;GroupKeyImpl是GroupKey接口的实现类,其alias返回的是GroupKeyImpl),一个是集合类型的AggCall;该方法先通过stack.pop()取出栈顶的Frame,再将AggCall转换为AggregateCall,然后调用aggregateFactory.createAggregate方法创建新的RelNode,并据此构造新的Frame,最后通过stack.push重新放入栈顶
- RelFactories定义了AggregateFactory接口,该接口定义了createAggregate方法,用于将一系列的AggregateCall操作转为新的RelNode;AggregateFactoryImpl是AggregateFactory接口的实现类,它的createAggregate方法调用的是LogicalAggregate.create方法,创建的是LogicalAggregate
网友评论