美文网首页
聊聊flink Table的groupBy操作

聊聊flink Table的groupBy操作

作者: go4it | 来源:发表于2019-01-25 10:16 被阅读23次

    本文主要研究一下flink Table的groupBy操作

    Table.groupBy

    flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala

    /**
      * Excerpt of the Table API class: pairs a logical plan with the table environment
      * it was created in. Only the `groupBy` entry points are shown here.
      */
    class Table(
        private[flink] val tableEnv: TableEnvironment,
        private[flink] val logicalPlan: LogicalNode) {
    
      //......
    
      /**
        * Groups this table by the fields described in an expression-list string.
        *
        * The string is parsed into expressions and handed to the expression-based overload.
        */
      def groupBy(fields: String): GroupedTable =
        groupBy(ExpressionParser.parseExpressionList(fields): _*)
    
      /** Groups this table by the given expressions, wrapping both in a [[GroupedTable]]. */
      def groupBy(fields: Expression*): GroupedTable = new GroupedTable(this, fields)
    
      //......
    }
    
    • Table的groupBy操作支持两种参数,一种是String类型,一种是Expression类型;String参数的方法是将String转换为Expression,最后调用的是Expression参数的groupBy方法,该方法创建了GroupedTable

    GroupedTable

    flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala

    /**
      * A table together with its grouping keys. Calling `select` turns it back into a
      * [[Table]] whose plan is a projection on top of an aggregate on top of a projection.
      */
    class GroupedTable(
      private[flink] val table: Table,
      private[flink] val groupKey: Seq[Expression]) {
    
      /**
        * Builds the aggregated table for the given projection expressions.
        *
        * Throws a ValidationException when the projection refers to window properties.
        */
      def select(fields: Expression*): Table = {
        val env = table.tableEnv
        val projection = expandProjectList(fields, table.logicalPlan, env)
        val (aggregations, windowProps) = extractAggregationsAndProperties(projection, env)
        if (windowProps.nonEmpty) {
          throw new ValidationException("Window properties can only be used on windowed tables.")
        }
    
        val outerProjection =
          replaceAggregationsAndProperties(projection, env, aggregations, windowProps)
        val inputFields = extractFieldReferences(projection ++ groupKey)
    
        // The plan is assembled bottom-up; every node is validated before it is used
        // as the child of the next one.
        val namedAggregations = aggregations.map(a => Alias(a._1, a._2)).toSeq
        val inputProject = Project(inputFields, table.logicalPlan).validate(env)
        val aggregate = Aggregate(groupKey, namedAggregations, inputProject).validate(env)
        new Table(env, Project(outerProjection, aggregate).validate(env))
      }
    
      /**
        * String variant of `select`: parses the expression list, resolves aggregate
        * function calls against the table environment and forwards to the
        * expression-based `select`.
        */
      def select(fields: String): Table = {
        // get the correct expression for AggFunctionCall
        val resolved = ExpressionParser
          .parseExpressionList(fields)
          .map(replaceAggFunctionCall(_, table.tableEnv))
        select(resolved: _*)
      }
    }
    
    • GroupedTable有两个属性,一个是原始的Table,一个是Seq[Expression]类型的groupKey
    • GroupedTable提供两个select方法,参数类型分别为String、Expression,String类型的参数最后也是转为Expression类型
    • select方法使用Project创建新的Table,而Project则是通过Aggregate来创建

    Aggregate

    flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/plan/logical/operators.scala

    /**
      * Logical plan node describing a grouped aggregation over a single child node.
      *
      * @param groupingExpressions  expressions forming the grouping key
      * @param aggregateExpressions named aggregate expressions evaluated per group
      * @param child                input node of the aggregation
      */
    case class Aggregate(
        groupingExpressions: Seq[Expression],
        aggregateExpressions: Seq[NamedExpression],
        child: LogicalNode) extends UnaryNode {
    
      /**
        * Output attributes: grouping expressions first, followed by the aggregate expressions.
        * Expressions that are not named are wrapped in an alias whose name is their
        * string representation.
        */
      override def output: Seq[Attribute] = {
        (groupingExpressions ++ aggregateExpressions) map {
          case ne: NamedExpression => ne.toAttribute
          case e => Alias(e, e.toString).toAttribute
        }
      }
    
      /**
        * Translates this node with the given RelBuilder: the child is constructed first,
        * then `relBuilder.aggregate` is called with a group key built from the grouping
        * expressions and one aggregate call per aliased aggregation.
        */
      override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
        child.construct(relBuilder)
        relBuilder.aggregate(
          relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
          aggregateExpressions.map {
            case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
            // Anything other than an aliased Aggregation is treated as an internal error.
            case _ => throw new RuntimeException("This should never happen.")
          }.asJava)
      }
    
      /**
        * Runs the parent validation and then checks every aggregate expression and every
        * grouping expression of the result. Problems are reported through `failValidation`.
        *
        * @return the aggregate node produced by `super.validate`
        */
      override def validate(tableEnv: TableEnvironment): LogicalNode = {
        implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
        val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
        val groupingExprs = resolvedAggregate.groupingExpressions
        val aggregateExprs = resolvedAggregate.aggregateExpressions
        // The two checker functions are local defs declared further down in this block.
        aggregateExprs.foreach(validateAggregateExpression)
        groupingExprs.foreach(validateGroupingExpression)
    
        // Recursively checks one expression of the aggregate list; case order matters.
        def validateAggregateExpression(expr: Expression): Unit = expr match {
          // distinct may wrap an aggregation only, and must not be chained
          case distinctExpr: DistinctAgg =>
            distinctExpr.child match {
              case _: DistinctAgg => failValidation(
                "Chained distinct operators are not supported!")
              case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
              case _ => failValidation(
                "Distinct operator can only be applied to aggregation expressions!")
            }
          // check aggregate function
          case aggExpr: Aggregation
            if aggExpr.getSqlAggFunction.requiresOver =>
            failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
          // check no nested aggregation exists.
          case aggExpr: Aggregation =>
            aggExpr.children.foreach { child =>
              child.preOrderVisit {
                case agg: Aggregation =>
                  failValidation(
                    "It's not allowed to use an aggregate function as " +
                      "input of another aggregate function")
                case _ => // OK
              }
            }
          // a plain attribute is only allowed when it is one of the grouping expressions
          case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
            failValidation(
              s"expression '$a' is invalid because it is neither" +
                " present in group by nor an aggregate function")
          case e if groupingExprs.exists(_.checkEquals(e)) => // OK
          // any other expression: validate its children
          case e => e.children.foreach(validateAggregateExpression)
        }
    
        // A grouping expression must have a result type that reports itself as a key type.
        def validateGroupingExpression(expr: Expression): Unit = {
          if (!expr.resultType.isKeyType) {
            failValidation(
              s"expression $expr cannot be used as a grouping expression " +
                "because it's not a valid key type which must be hashable and comparable")
          }
        }
        resolvedAggregate
      }
    }
    
    • Aggregate继承了UnaryNode,它接收三个参数,一个是Seq[Expression]类型的groupingExpressions,一个是Seq[NamedExpression]类型的aggregateExpressions,一个是LogicalNode类型的child;construct方法调用了relBuilder.aggregate,传入的RelBuilder.GroupKey参数是通过relBuilder.groupKey构建,而传入的RelBuilder.AggCall参数则是通过aggregateExpressions.map构造而来

    RelBuilder.groupKey

    calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java

    public class RelBuilder {
      protected final RelOptCluster cluster;
      protected final RelOptSchema relOptSchema;
      private final RelFactories.FilterFactory filterFactory;
      private final RelFactories.ProjectFactory projectFactory;
      private final RelFactories.AggregateFactory aggregateFactory;
      private final RelFactories.SortFactory sortFactory;
      private final RelFactories.ExchangeFactory exchangeFactory;
      private final RelFactories.SortExchangeFactory sortExchangeFactory;
      private final RelFactories.SetOpFactory setOpFactory;
      private final RelFactories.JoinFactory joinFactory;
      private final RelFactories.SemiJoinFactory semiJoinFactory;
      private final RelFactories.CorrelateFactory correlateFactory;
      private final RelFactories.ValuesFactory valuesFactory;
      private final RelFactories.TableScanFactory scanFactory;
      private final RelFactories.MatchFactory matchFactory;
      private final Deque<Frame> stack = new ArrayDeque<>();
      private final boolean simplify;
      private final RexSimplify simplifier;
    
      //......
    
      /** Returns a group key that contains no expressions. */
      public GroupKey groupKey() {
        return groupKey(ImmutableList.of());
      }
    
      /** Returns a group key over the given expressions (varargs form). */
      public GroupKey groupKey(RexNode... nodes) {
        return groupKey(ImmutableList.copyOf(nodes));
      }
    
      /** Returns a group key over the given expressions; no grouping sets and
       * no alias are attached. */
      public GroupKey groupKey(Iterable<? extends RexNode> nodes) {
        return new GroupKeyImpl(ImmutableList.copyOf(nodes), false, null, null);
      }
    
      /** Returns a group key over {@code nodes} with the grouping sets given by
       * {@code nodeLists}. */
      public GroupKey groupKey(Iterable<? extends RexNode> nodes,
          Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
        return groupKey_(nodes, false, nodeLists);
      }
    
      /** Returns a group key over the fields with the given ordinals. */
      public GroupKey groupKey(int... fieldOrdinals) {
        return groupKey(fields(ImmutableIntList.of(fieldOrdinals)));
      }
    
      /** Returns a group key over the fields with the given names. */
      public GroupKey groupKey(String... fieldNames) {
        return groupKey(fields(ImmutableList.copyOf(fieldNames)));
      }
    
      /** Returns a group key for a bit set of field ordinals; the bit set itself
       * is used as the single grouping set. */
      public GroupKey groupKey(@Nonnull ImmutableBitSet groupSet) {
        return groupKey(groupSet, ImmutableList.of(groupSet));
      }
    
      /** Returns a group key for a bit set of field ordinals together with
       * explicit grouping sets. */
      public GroupKey groupKey(ImmutableBitSet groupSet,
          @Nonnull Iterable<? extends ImmutableBitSet> groupSets) {
        return groupKey_(groupSet, false, ImmutableList.copyOf(groupSets));
      }
    
      /** Bit-set flavour of the internal factory: rejects a group set whose
       * length exceeds the field count of the current input, then converts the
       * ordinals to field references and delegates to the expression flavour. */
      private GroupKey groupKey_(ImmutableBitSet groupSet, boolean indicator,
          @Nonnull ImmutableList<ImmutableBitSet> groupSets) {
        final int bitLength = groupSet.length();
        final int inputFieldCount = peek().getRowType().getFieldCount();
        if (bitLength > inputFieldCount) {
          throw new IllegalArgumentException("out of bounds: " + groupSet);
        }
        Objects.requireNonNull(groupSets);
        final ImmutableList<RexNode> keyNodes =
            fields(ImmutableIntList.of(groupSet.toArray()));
        final List<ImmutableList<RexNode>> setNodes =
            Util.transform(groupSets,
                bits -> fields(ImmutableIntList.of(bits.toArray())));
        return groupKey_(keyNodes, indicator, setNodes);
      }
    
      /** Expression flavour of the internal factory: takes immutable copies of
       * every grouping set and of the key expressions and wraps them in a
       * {@link GroupKeyImpl} without an alias. */
      private GroupKey groupKey_(Iterable<? extends RexNode> nodes,
          boolean indicator,
          Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
        final ImmutableList.Builder<ImmutableList<RexNode>> copiedSets =
            ImmutableList.builder();
        for (Iterable<? extends RexNode> groupingSet : nodeLists) {
          copiedSets.add(ImmutableList.copyOf(groupingSet));
        }
        final ImmutableList<RexNode> keyNodes = ImmutableList.copyOf(nodes);
        return new GroupKeyImpl(keyNodes, indicator, copiedSets.build(), null);
      }
    
      //......
    }
    
    • RelBuilder提供了诸多groupKey方法用于创建GroupKey;其中groupKey(Iterable<? extends RexNode>)直接创建GroupKeyImpl,而带grouping sets或ImmutableBitSet参数的重载最后调用的是私有方法groupKey_,该方法同样创建了GroupKeyImpl

    GroupKey

    calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java

      /** Group key as accepted by {@code RelBuilder.aggregate}; the only operation
       * declared here is assigning an alias. */
      public interface GroupKey {
        /** Assigns an alias to this group key.
         *
         * <p>Used to assign field names in the {@code group} operation.
         *
         * @param alias the alias to assign
         * @return a group key carrying the alias */
        GroupKey alias(String alias);
      }
    
      /** Implementation of {@link GroupKey}. */
      protected static class GroupKeyImpl implements GroupKey {
        final ImmutableList<RexNode> nodes;
        final boolean indicator;
        final ImmutableList<ImmutableList<RexNode>> nodeLists;
        final String alias;
    
        GroupKeyImpl(ImmutableList<RexNode> nodes, boolean indicator,
            ImmutableList<ImmutableList<RexNode>> nodeLists, String alias) {
          this.nodes = Objects.requireNonNull(nodes);
          assert !indicator;
          this.indicator = indicator;
          this.nodeLists = nodeLists;
          this.alias = alias;
        }
    
        @Override public String toString() {
          return alias == null ? nodes.toString() : nodes + " as " + alias;
        }
    
        public GroupKey alias(String alias) {
          return Objects.equals(this.alias, alias)
              ? this
              : new GroupKeyImpl(nodes, indicator, nodeLists, alias);
        }
      }
    
    • GroupKey接口定义了alias方法,用于给group操作的字段指定别名;GroupKeyImpl是GroupKey接口的实现类,其alias方法在别名未变化时返回自身,否则返回带有新别名的GroupKeyImpl

    RelBuilder.aggregate

    calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java

    public class RelBuilder {
      protected final RelOptCluster cluster;
      protected final RelOptSchema relOptSchema;
      private final RelFactories.FilterFactory filterFactory;
      private final RelFactories.ProjectFactory projectFactory;
      private final RelFactories.AggregateFactory aggregateFactory;
      private final RelFactories.SortFactory sortFactory;
      private final RelFactories.ExchangeFactory exchangeFactory;
      private final RelFactories.SortExchangeFactory sortExchangeFactory;
      private final RelFactories.SetOpFactory setOpFactory;
      private final RelFactories.JoinFactory joinFactory;
      private final RelFactories.SemiJoinFactory semiJoinFactory;
      private final RelFactories.CorrelateFactory correlateFactory;
      private final RelFactories.ValuesFactory valuesFactory;
      private final RelFactories.TableScanFactory scanFactory;
      private final RelFactories.MatchFactory matchFactory;
      // Stack of frames; aggregate() below pops the top frame and pushes a new one.
      private final Deque<Frame> stack = new ArrayDeque<>();
      private final boolean simplify;
      private final RexSimplify simplifier;
    
      //......
    
      /** Creates an {@link Aggregate} with an array of
       * calls. */
      public RelBuilder aggregate(GroupKey groupKey, AggCall... aggCalls) {
        return aggregate(groupKey, ImmutableList.copyOf(aggCalls));
      }
    
      /** Creates an {@link Aggregate} from ready-made {@link AggregateCall}s by
       * wrapping each of them in an {@code AggCallImpl2}. */
      public RelBuilder aggregate(GroupKey groupKey,
          List<AggregateCall> aggregateCalls) {
        return aggregate(groupKey,
            Lists.transform(aggregateCalls, AggCallImpl2::new));
      }
    
      /** Creates an {@link Aggregate} with a list of
       * calls.
       *
       * <p>Steps, in order: register the group key and aggregate arguments as
       * expressions over the current input, project and rename the input
       * accordingly, pop the input frame, convert every {@code AggCall} into an
       * {@link AggregateCall}, create the aggregate through
       * {@code aggregateFactory} and push a new frame describing its fields.
       *
       * <p>When there are no aggregate calls the method may return early without
       * creating an aggregate (see the labeled block below). */
      public RelBuilder aggregate(GroupKey groupKey, Iterable<AggCall> aggCalls) {
        // Seed the registrar with the fields and field names of the current input.
        final Registrar registrar = new Registrar();
        registrar.extraNodes.addAll(fields());
        registrar.names.addAll(peek().getRowType().getFieldNames());
        final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey;
        // Bit set built from the indexes returned when registering the key expressions.
        final ImmutableBitSet groupSet =
            ImmutableBitSet.of(registrar.registerExpressions(groupKey_.nodes));
      // Labeled block: "break label" leaves the shortcut checks and continues with
      // the regular aggregate construction below.
      label:
        if (Iterables.isEmpty(aggCalls) && !groupKey_.indicator) {
          final RelMetadataQuery mq = peek().getCluster().getMetadataQuery();
          if (groupSet.isEmpty()) {
            final Double minRowCount = mq.getMinRowCount(peek());
            if (minRowCount == null || minRowCount < 1D) {
              // We can't remove "GROUP BY ()" if there's a chance the rel could be
              // empty.
              break label;
            }
          }
          // Only applies when registering the group key added no new expressions.
          if (registrar.extraNodes.size() == fields().size()) {
            final Boolean unique = mq.areColumnsUnique(peek(), groupSet);
            if (unique != null && unique) {
              // Rel is already unique.
              return project(fields(groupSet.asList()));
            }
          }
          final Double maxRowCount = mq.getMaxRowCount(peek());
          if (maxRowCount != null && maxRowCount <= 1D) {
            // If there is at most one row, rel is already unique.
            return this;
          }
        }
        // Resolve the grouping sets; without explicit node lists the group set
        // itself is the only grouping set.
        final ImmutableList<ImmutableBitSet> groupSets;
        if (groupKey_.nodeLists != null) {
          final int sizeBefore = registrar.extraNodes.size();
          final SortedSet<ImmutableBitSet> groupSetSet =
              new TreeSet<>(ImmutableBitSet.ORDERING);
          for (ImmutableList<RexNode> nodeList : groupKey_.nodeLists) {
            final ImmutableBitSet groupSet2 =
                ImmutableBitSet.of(registrar.registerExpressions(nodeList));
            if (!groupSet.contains(groupSet2)) {
              throw new IllegalArgumentException("group set element " + nodeList
                  + " must be a subset of group key");
            }
            groupSetSet.add(groupSet2);
          }
          groupSets = ImmutableList.copyOf(groupSetSet);
          // Registering the grouping sets must not have introduced new expressions.
          if (registrar.extraNodes.size() > sizeBefore) {
            throw new IllegalArgumentException(
                "group sets contained expressions not in group key: "
                    + registrar.extraNodes.subList(sizeBefore,
                    registrar.extraNodes.size()));
          }
        } else {
          groupSets = ImmutableList.of(groupSet);
        }
        // First pass over the calls: register operands and filters so that they
        // are part of the projection created just below.
        for (AggCall aggCall : aggCalls) {
          if (aggCall instanceof AggCallImpl) {
            final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
            registrar.registerExpressions(aggCall1.operands);
            if (aggCall1.filter != null) {
              registrar.registerExpression(aggCall1.filter);
            }
          }
        }
        // Project all registered expressions, apply the registered names, then
        // take the resulting frame off the stack; its rel is the aggregate input.
        project(registrar.extraNodes);
        rename(registrar.names);
        final Frame frame = stack.pop();
        final RelNode r = frame.rel;
        // Second pass: turn every AggCall into an AggregateCall.
        final List<AggregateCall> aggregateCalls = new ArrayList<>();
        for (AggCall aggCall : aggCalls) {
          final AggregateCall aggregateCall;
          if (aggCall instanceof AggCallImpl) {
            final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
            final List<Integer> args =
                registrar.registerExpressions(aggCall1.operands);
            // -1 is passed as the filter argument when the call has no filter.
            final int filterArg = aggCall1.filter == null ? -1
                : registrar.registerExpression(aggCall1.filter);
            if (aggCall1.distinct && !aggCall1.aggFunction.isQuantifierAllowed()) {
              throw new IllegalArgumentException("DISTINCT not allowed");
            }
            if (aggCall1.filter != null && !aggCall1.aggFunction.allowsFilter()) {
              throw new IllegalArgumentException("FILTER not allowed");
            }
            RelCollation collation =
                RelCollations.of(aggCall1.orderKeys
                    .stream()
                    .map(orderKey ->
                        collation(orderKey, RelFieldCollation.Direction.ASCENDING,
                            null, Collections.emptyList()))
                    .collect(Collectors.toList()));
            aggregateCall =
                AggregateCall.create(aggCall1.aggFunction, aggCall1.distinct,
                    aggCall1.approximate, args, filterArg, collation,
                    groupSet.cardinality(), r, null, aggCall1.alias);
          } else {
            // Any other AggCall is cast to AggCallImpl2, which already carries an AggregateCall.
            aggregateCall = ((AggCallImpl2) aggCall).aggregateCall;
          }
          aggregateCalls.add(aggregateCall);
        }
    
        assert ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets) : groupSets;
        for (ImmutableBitSet set : groupSets) {
          assert groupSet.contains(set);
        }
        RelNode aggregate = aggregateFactory.createAggregate(r,
            groupKey_.indicator, groupSet, groupSets, aggregateCalls);
    
        // build field list
        final ImmutableList.Builder<Field> fields = ImmutableList.builder();
        final List<RelDataTypeField> aggregateFields =
            aggregate.getRowType().getFieldList();
        int i = 0;
        // first, group fields
        for (Integer groupField : groupSet.asList()) {
          RexNode node = registrar.extraNodes.get(groupField);
          final SqlKind kind = node.getKind();
          switch (kind) {
          case INPUT_REF:
            // A plain input reference reuses the field of the popped input frame.
            fields.add(frame.fields.get(((RexInputRef) node).getIndex()));
            break;
          default:
            // Any other expression gets a new field named after the aggregate's row type.
            String name = aggregateFields.get(i).getName();
            RelDataTypeField fieldType =
                new RelDataTypeFieldImpl(name, i, node.getType());
            fields.add(new Field(ImmutableSet.of(), fieldType));
            break;
          }
          i++;
        }
        // second, indicator fields (copy from aggregate rel type)
        if (groupKey_.indicator) {
          for (int j = 0; j < groupSet.cardinality(); ++j) {
            final RelDataTypeField field = aggregateFields.get(i);
            final RelDataTypeField fieldType =
                new RelDataTypeFieldImpl(field.getName(), i, field.getType());
            fields.add(new Field(ImmutableSet.of(), fieldType));
            i++;
          }
        }
        // third, aggregate fields. retain `i' as field index
        for (int j = 0; j < aggregateCalls.size(); ++j) {
          final AggregateCall call = aggregateCalls.get(j);
          final RelDataTypeField fieldType =
              new RelDataTypeFieldImpl(aggregateFields.get(i + j).getName(), i + j,
                  call.getType());
          fields.add(new Field(ImmutableSet.of(), fieldType));
        }
        // The new aggregate and its field list become the top frame of the stack.
        stack.push(new Frame(aggregate, fields.build()));
        return this;
      }
    
      //......
    }
    
    • RelBuilder的aggregate操作接收两个参数,一个是GroupKey,一个是集合类型的AggCall;它先通过stack.pop()取出stack队首的Frame,再将AggCall转换为AggregateCall,然后通过aggregateFactory.createAggregate方法创建新的RelNode,构造新的Frame,最后重新放入stack的队首

    RelFactories.AggregateFactory.createAggregate

    calcite-core-1.18.0-sources.jar!/org/apache/calcite/rel/core/RelFactories.java

    public class RelFactories {
      //......
      public static final AggregateFactory DEFAULT_AGGREGATE_FACTORY =
          new AggregateFactoryImpl();
    
      /** Factory abstraction for the relational node that represents an
       * aggregation. */
      public interface AggregateFactory {
        /** Builds an aggregate over {@code input} for the given group set,
         * grouping sets and aggregate calls. */
        RelNode createAggregate(RelNode input, boolean indicator,
            ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets,
            List<AggregateCall> aggCalls);
      }
    
      /** Default {@link AggregateFactory}; its result is a
       * {@link LogicalAggregate}. */
      private static class AggregateFactoryImpl implements AggregateFactory {
        @SuppressWarnings("deprecation")
        @Override public RelNode createAggregate(RelNode input, boolean indicator,
            ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets,
            List<AggregateCall> aggCalls) {
          // The five-argument create overload is the one that accepts the indicator flag.
          final LogicalAggregate logicalAggregate =
              LogicalAggregate.create(input, indicator, groupSet, groupSets, aggCalls);
          return logicalAggregate;
        }
      }
    
      //......
    }
    
    • RelFactories定义了AggregateFactory接口,该接口定义了createAggregate方法,用于将一系列的AggregateCall操作转为新的RelNode;AggregateFactoryImpl是AggregateFactory接口的实现类,它的createAggregate方法调用的是LogicalAggregate.create方法

    LogicalAggregate.create

    calcite-core-1.18.0-sources.jar!/org/apache/calcite/rel/logical/LogicalAggregate.java

    public final class LogicalAggregate extends Aggregate {
      //......
    
      /** Creates a LogicalAggregate with the indicator flag fixed to
       * {@code false}. */
      public static LogicalAggregate create(final RelNode input,
          ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
          List<AggregateCall> aggCalls) {
        final boolean indicator = false;
        return create_(input, indicator, groupSet, groupSets, aggCalls);
      }
    
      /** Creates a LogicalAggregate with a caller-supplied indicator flag. */
      @Deprecated // to be removed before 2.0
      public static LogicalAggregate create(final RelNode input,
          boolean indicator,
          ImmutableBitSet groupSet,
          List<ImmutableBitSet> groupSets,
          List<AggregateCall> aggCalls) {
        return create_(input, indicator, groupSet, groupSets, aggCalls);
      }
    
      /** Shared factory: the new node uses the cluster of {@code input} and a
       * trait set built from {@code Convention.NONE}. */
      private static LogicalAggregate create_(final RelNode input,
          boolean indicator,
          ImmutableBitSet groupSet,
          List<ImmutableBitSet> groupSets,
          List<AggregateCall> aggCalls) {
        final RelOptCluster inputCluster = input.getCluster();
        return new LogicalAggregate(inputCluster,
            inputCluster.traitSetOf(Convention.NONE), input, indicator, groupSet,
            groupSets, aggCalls);
      }
    
      //......
    }
    
    • LogicalAggregate的create方法创建的是LogicalAggregate

    小结

    • Table的groupBy操作支持两种参数,一种是String类型,一种是Expression类型;String参数的方法是将String转换为Expression,最后调用的是Expression参数的groupBy方法,该方法创建了GroupedTable
    • GroupedTable有两个属性,一个是原始的Table,一个是Seq[Expression]类型的groupKey;它提供两个select方法,参数类型分别为String、Expression,String类型的参数最后也是转为Expression类型;select方法使用Project创建新的Table,而Project则是通过Aggregate来创建
    • Aggregate继承了UnaryNode,它接收三个参数,一个是Seq[Expression]类型的groupingExpressions,一个是Seq[NamedExpression]类型的aggregateExpressions,一个是LogicalNode类型的child;construct方法调用了relBuilder.aggregate,传入的RelBuilder.GroupKey参数是通过relBuilder.groupKey构建,而传入的RelBuilder.AggCall参数则是通过aggregateExpressions.map构造而来
    • RelBuilder的aggregate操作接收两个参数,一个是GroupKey(GroupKey接口定义了alias方法,用于给group操作的字段指定别名;GroupKeyImpl是GroupKey接口的实现类,其alias方法在别名未变化时返回自身,否则返回带有新别名的GroupKeyImpl),一个是集合类型的AggCall;它先通过stack.pop()取出stack队首的Frame,再将AggCall转换为AggregateCall,然后通过aggregateFactory.createAggregate方法创建新的RelNode,构造新的Frame,最后重新放入stack的队首
    • RelFactories定义了AggregateFactory接口,该接口定义了createAggregate方法,用于将一系列的AggregateCall操作转为新的RelNode;AggregateFactoryImpl是AggregateFactory接口的实现类,它的createAggregate方法调用的是LogicalAggregate.create方法,创建的是LogicalAggregate

    doc

    相关文章

      网友评论

          本文标题:聊聊flink Table的groupBy操作

          本文链接:https://www.haomeiwen.com/subject/rfqmjqtx.html