美文网首页
spark aggregate & treeAggregate

spark aggregate & treeAggregate

作者: _zzzZzzz_ | 来源:发表于2019-01-25 18:50 被阅读0次

    aggregate和treeAggregate都是org.apache.spark.rdd包下的RDD类的方法。

    aggregate

    首先来看这个方法的签名

    abstract class RDD[T: ClassTag](
        @transient private var _sc: SparkContext,
        @transient private var deps: Seq[Dependency[_]]
      ) extends Serializable with Logging {
    ...
    ...
    /**
       * Aggregate the elements of each partition, and then the results for all the partitions, using
       * given combine functions and a neutral "zero value". This function can return a different result
       * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
       * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
       * allowed to modify and return their first argument instead of creating a new U to avoid memory
       * allocation.
       *
       * @param zeroValue the initial value for the accumulated result of each partition for the
       *                  `seqOp` operator, and also the initial value for the combine results from
       *                  different partitions for the `combOp` operator - this will typically be the
       *                  neutral element (e.g. `Nil` for list concatenation or `0` for summation)
       * @param seqOp an operator used to accumulate results within a partition
       * @param combOp an associative operator used to combine results from different partitions
       */
      def aggregate[U: ClassTag](zeroValue: U)(
          seqOp: (U, T) => U, 
          combOp: (U, U) => U): U = withScope {...}
    ...
    ...
    }
    

    可以看到aggregate方法接受三个参数:aggregate(zeroValue)(seqOp, combOp),这里参数被分在两组括号中,这个写法涉及到了柯里化(Currying),感兴趣的同学可以去进一步了解一下。

    OK,现在简单翻译一下这个方法的注释:

    该方法(function)首先对每个partition的元素执行聚合(aggregate)操作,然后对所有partition的结果再次执行聚合操作。
    聚合操作使用了传入参数中的combOp作为聚合函数(combine functions),使用zeroValue作为聚合操作中的零元
    本方法返回值的类型U可不同于所属RDD对象的类型:T。
    因此我们需要一个函数来将一个T对象转换(merge into)为一个U对象,以及一个函数来将两个U对象合并(merge)为一个U对象。如同scala.TraversableOnce一样。这两个方法都支持对其第一个入参修改,从而避免频繁申请空间创建新的U对象。

    zeroValue : 零元,seqOp方法的初始值,也是combOp方法的初始值(如list拼接中的Nil、加法中的0
    seqOp : 单partition做聚合操作的方法
    combOp : 多partition之间做合并的方法

    下面来看一个具体的例子(参考了这篇回答):

    scala> val listRDD = spark.sparkContext.parallelize(Seq(1,2,3,4), 2)
    scala> def seqOp(localResult: Seq[Int], listElement: Int) = {Seq(localResult(0) + listElement, localResult(1) + 1) }
    scala> def combOp(localResultA: Seq[Int], localResultB: Seq[Int]) = {Seq(localResultA(0)+localResultB(0), localResultA(1)+localResultB(1))}
    scala> listRDD.aggregate(Seq(0, 0))(seqOp, combOp)
    res1: Seq[Int] = List(10, 4)
    

    这里新建了一个序列Seq(1, 2, 3, 4),并划分到两个partition中:

    partition0: Seq(1, 2)
    partition1: Seq(3, 4)
    

    最终想统计一个数对Seq(序列的和, 序列元素个数)。序列的和为1+2+3+4=10,序列个数显然是4
    计算方法如下:

    1. 对每个partition:
      a. 初始化聚合结果为Seq(0, 0)
      b. 对当前partition的序列元素,依次执行聚合操作seqOp
      c. 得到当前partition的聚合结果Seq(partition_sum, partition_count)

    2. 对所有partition:
      a. 依次合并各partition的聚合结果,合并方法为combOp
      b. 得到合并结果Seq(total_sum, total_count)

    计算过程如下图所示:

                (0, 0) <-- zeroValue
    
    [1, 2]                  [3, 4]
    
    0 + 1 = 1               0 + 3 = 3
    0 + 1 = 1               0 + 1 = 1
    
    1 + 2 = 3               3 + 4 = 7
    1 + 1 = 2               1 + 1 = 2       
        |                       |
        v                       v
      (3, 2)                  (7, 2)
          \                    / 
            \                /
              \            /
               ------------
               |  combOp  |
               ------------
                    |
                    v
                 (10, 4)
    

    这里需要注意,当我们把zeroValue改为(1, 0)的时候,我们其实是无法通过上面的图示来预期结果的,结果并不一定会变为(12, 4),因为在spark内部计算的时候,可能会多次使用该值做初始化。因此在选择zeroValue的时候应谨慎。
    OK,到此为止我们就了解了aggregate的使用方法,下面来看treeAggregate

    treeAggregate

    /**
       * Aggregates the elements of this RDD in a multi-level tree pattern.
       * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]].
       *
       * @param depth suggested depth of the tree (default: 2)
       */
      def treeAggregate[U: ClassTag](zeroValue: U)(
          seqOp: (U, T) => U,
          combOp: (U, U) => U,
          depth: Int = 2): U = withScope {...}
    

    其实基本上和aggregate是一样的,但是在aggregate中,需要把各partition的结果汇总发到driver上使用combOp进行最后一步汇总合并,这里有时会成为瓶颈(带宽、依次遍历各partition结果并合并),而treeAggregate就是用来优化这一环节的,按照树结构来reduce,提升性能。

    treeAggregate提供了一个新的参数depth,就是用来指定这个reduce树的深度的,默认为2。

    了解了aggregatetreeAggregate后,我们就知道了,在实际使用中,尽量还是使用treeAggregate吧。

    相关文章

      网友评论

          本文标题:spark aggregate & treeAggregate

          本文链接:https://www.haomeiwen.com/subject/wyqvjqtx.html