美文网首页spark生态系统大数据spark
[spark] Shuffle Write解析 (Sort Ba

[spark] Shuffle Write解析 (Sort Ba

作者: BIGUFO | 来源:发表于2017-11-09 20:12 被阅读206次

    本文基于 Spark 2.1 进行解析

    前言

    从 Spark 2.0 开始移除了Hash Based Shuffle,想要了解可参考Shuffle 过程,本文将讲解 Sort Based Shuffle。

    ShuffleMapTask的结果(ShuffleMapStage中FinalRDD的数据)都将写入磁盘,以供后续Stage拉取,即整个Shuffle包括前Stage的Shuffle Write和后Stage的Shuffle Read,由于内容较多,本文先解析Shuffle Write。

    概述:

    • 写records到内存缓冲区(一个数组维护的map),每次insert&update都需要检查是否达到溢写条件。
    • 若需要溢写,将集合中的数据根据partitionId和key(若需要)排序后顺序溢写到一个临时的磁盘文件,并释放内存新建一个map放数据,每次溢写都是写一个新的临时文件。
    • 一个task最终对应一个文件,将还在内存中的数据和已经spill的文件根据reduce端的partitionId进行合并,合并后需要再次聚合排序(有需要情况下),再根据partition的顺序写入最终文件,并返回每个partition在文件中的偏移量,最后以MapStatus对象返回给driver并注册到MapOutputTrackerMaster中,后续reduce好通过它来访问。

    入口

    执行一个ShuffleMapTask最终的执行逻辑是调用了ShuffleMapTask类
    的runTask()方法:

    override def runTask(context: TaskContext): MapStatus = {
        // Deserialize the RDD using the broadcast variable.
        val deserializeStartTime = System.currentTimeMillis()
        val ser = SparkEnv.get.closureSerializer.newInstance()
        // 从广播变量中反序列化出finalRDD和dependency
        val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    
        var writer: ShuffleWriter[Any, Any] = null
        try {
          // 获取shuffleManager
          val manager = SparkEnv.get.shuffleManager
          // 通过shuffleManager的getWriter()方法,获得shuffle的writer
          writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
           // 通过rdd指定分区的迭代器iterator方法来遍历每一条数据,再之上再调用writer的write方法以写数据
          writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
          writer.stop(success = true).get
        } catch {
          case e: Exception =>
            try {
              if (writer != null) {
                writer.stop(success = false)
              }
            } catch {
              case e: Exception =>
                log.debug("Could not stop writer", e)
            }
            throw e
        }
      }
    

    其中的finalRDD和dependency是在Driver端DAGScheluer中提交Stage的时候加入广播变量的。

    接着通过SparkEnv获取shuffleManager,默认使用的是sort(对应的是org.apache.spark.shuffle.sort.SortShuffleManager),可通过spark.shuffle.manager设置。

    然后调用了manager.getWriter方法,该方法中检测到满足Unsafe Shuffle条件会自动采用Unsafe Shuffle,否则采用Sort Shuffle。使用Unsafe Shuffle有几个限制,shuffle阶段不能有aggregate操作,分区数不能超过一定大小( 2^24−1,这是可编码的最大parition id),所以像reduceByKey这类有aggregate操作的算子是不能使用Unsafe Shuffle。

    这里暂时讨论Sort Shuffle的情况,即getWriter返回的是SortShuffleWriter,我们直接看writer.write发生了什么:

    override def write(records: Iterator[Product2[K, V]]): Unit = {
        sorter = if (dep.mapSideCombine) {
          require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
          new ExternalSorter[K, V, C](
            context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
        } else {
          new ExternalSorter[K, V, V](
            context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
        }
        // 写内存缓冲区,超过阈值则溢写到磁盘文件
        sorter.insertAll(records)
        // 获取该task的最终输出文件
        val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
        val tmp = Utils.tempFileWith(output)
        try {
          val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
          // merge后写到data文件
          val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
          // 写index文件
          shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
          mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
        } finally {
          if (tmp.exists() && !tmp.delete()) {
            logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
          }
        }
      }
    
    • 通过判断是否有map端的combine来创建不同的ExternalSorter,若有则将对应的aggregator和keyOrdering作为参数传入。
    • 调用sorter.insertAll(records),将records写入内存缓冲区,超过阈值则溢写到磁盘文件。
    • Merge内存记录和所有被spill到磁盘的文件,并写到最终的数据文件.data中。
    • 将每个partition的偏移量写到index文件中。

    先细看sorter.inster是怎么写到内存,并spill到磁盘文件的:

    def insertAll(records: Iterator[Product2[K, V]]): Unit = {
        // TODO: stop combining if we find that the reduction factor isn't high
        val shouldCombine = aggregator.isDefined
        // 若需要Combine
        if (shouldCombine) {
          // 获取对新value合并到聚合结果中的函数
          val mergeValue = aggregator.get.mergeValue
          // 获取创建初始聚合值的函数
          val createCombiner = aggregator.get.createCombiner
          var kv: Product2[K, V] = null
          // 通过mergeValue 对已有的聚合结果的新value进行合并,通过createCombiner 对没有聚合结果的新value初始化聚合结果
          val update = (hadValue: Boolean, oldValue: C) => {
            if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
          }
          // 遍历records
          while (records.hasNext) {
            addElementsRead()
            kv = records.next()
            // 使用update函数进行value的聚合
            map.changeValue((getPartition(kv._1), kv._1), update)
            // 是否需要spill到磁盘文件
            maybeSpillCollection(usingMap = true)
          }
        // 不需要Combine
        } else {
          // Stick values into our buffer
          while (records.hasNext) {
            addElementsRead()
            val kv = records.next()
            buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
            maybeSpillCollection(usingMap = false)
          }
        }
      }
    
    • 需要聚合的情况,遍历records拿到record的KV,通过map的changeValue方法并根据update函数来对相同K的V进行聚合,这里的map是PartitionedAppendOnlyMap类型,只能添加数据不能删除数据,底层实现是一个数组,数组中存KV键值对的方式是[K1,V1,K2,V2...],每一次操作后都会判断是否要spill到磁盘。

    • 不需要聚合的情况,直接将record放入buffer,然后判断是否要溢写到磁盘。

    先看map.changeValue方法到底是怎么通过map实现对数据combine的:

    override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
        // 通过聚合算法得到newValue
        val newValue = super.changeValue(key, updateFunc)
        // 跟新对map的大小采样
        super.afterUpdate()
        newValue
      }
    

    super.changeValue的实现:

    def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
        ...
        // 根据k 得到pos
        var pos = rehash(k.hashCode) & mask
        var i = 1
        while (true) {
          // 从data中获取该位置的原来的key
          val curKey = data(2 * pos)  
          // 若原来的key和当前的key相等,则将两个值进行聚合
          if (k.eq(curKey) || k.equals(curKey)) {
            val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
            data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
            return newValue
           // 若当前key对应的位置没有key,则将当前key作为该位置的key
           // 并通过update方法初始化该位置的聚合结果
          } else if (curKey.eq(null)) {
            val newValue = updateFunc(false, null.asInstanceOf[V])
            data(2 * pos) = k
            data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
            // 扩容
            incrementSize()
            return newValue
          // 若对应位置有key但不和当前key相等,即hash冲突了,则继续向后遍历
          } else {
            val delta = i
            pos = (pos + delta) & mask
            i += 1
          }
        }
        null.asInstanceOf[V] // Never reached but needed to keep compiler happy
      }
    

    根据K的hashCode再哈希与上掩码 得到 pos,2 * pos 为 k 应该所在的位置,2 * pos + 1 为 k 对应的 v 所在的位置,获取k应该所在位置的原来的key:

    • 若原来的key和当前的 k 相等,则通过update函数将两个v进行聚合并更新该位置的value
    • 若原来的key存在但不和当前的k 相等,则说明hash冲突了,更新pos继续遍历
    • 若原来的key不存在,则将当前k作为该位置的key,并通过update函数初始化该k对应的聚合结果,接着会通过incrementSize()方法进行扩容:
       private def incrementSize() {
          curSize += 1
          if (curSize > growThreshold) {
            growTable()
          }
        }
      
      跟新curSize,若当前大小超过了阈值growThreshold(growThreshold是当前容量capacity的0.7倍),则通过growTable()来扩容:
    protected def growTable() {
        // 容量翻倍
        val newCapacity = capacity * 2
        require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements")
        //生成新的数组来存数据
        val newData = new Array[AnyRef](2 * newCapacity)
        val newMask = newCapacity - 1
        var oldPos = 0
        while (oldPos < capacity) {
          // 将旧数组中的数据重新计算位置放到新的数组中
          if (!data(2 * oldPos).eq(null)) {
            val key = data(2 * oldPos)
            val value = data(2 * oldPos + 1)
            var newPos = rehash(key.hashCode) & newMask
            var i = 1
            var keepGoing = true
            while (keepGoing) {
              val curKey = newData(2 * newPos)
              if (curKey.eq(null)) {
                newData(2 * newPos) = key
                newData(2 * newPos + 1) = value
                keepGoing = false
              } else {
                val delta = i
                newPos = (newPos + delta) & newMask
                i += 1
              }
            }
          }
          oldPos += 1
        }
        // 替换及跟新变量
        data = newData
        capacity = newCapacity
        mask = newMask
        growThreshold = (LOAD_FACTOR * newCapacity).toInt
      }
    

    这里重新创建了一个两倍capacity 的数组来存放数据,将原来数组中的数据通过重新计算位置放到新数组里,将data替换为新的数组,并跟新一些变量。

    此时聚合已经完成,回到changeValue方面里面,接下来会执行super.afterUpdate()方法来对map的大小进行采样:

    protected def afterUpdate(): Unit = {
        numUpdates += 1
        if (nextSampleNum == numUpdates) {
          takeSample()
        }
      }
    

    若每遍历跟新一条record,都来对map进行采样估计大小,假设采样一次需要1ms,100w次采样就会花上16.7分钟,性能大大降低。所以这里只有当update次数达到nextSampleNum 的时候才通过takeSample()采样一次:

    private def takeSample(): Unit = {
        samples.enqueue(Sample(SizeEstimator.estimate(this), numUpdates))
        // Only use the last two samples to extrapolate
        if (samples.size > 2) {
          samples.dequeue()
        }
        // 估计每次跟新的变化量
        val bytesDelta = samples.toList.reverse match {
          case latest :: previous :: tail =>
            (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
          // If fewer than 2 samples, assume no change
          case _ => 0
        }
        // 跟新变化量
        bytesPerUpdate = math.max(0, bytesDelta)
        // 获取下次采样的次数
        nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
      }
    

    这里估计每次跟新的变化量的逻辑是:(当前map大小-上次采样的时候的大小) / (当前update的次数 - 上次采样的时候的update次数)。

    接着计算下次需要采样的update次数,该次数是指数级增长的,基数是1.1,第一次采样后,要1.1次进行第二次采样,第1.1*1.1次后进行第三次采样,以此类推,开始增长慢,后面增长跨度会非常大。

    这里采样完成后回到insetAll方法,接着通过maybeSpillCollection方法判断是否需要spill:

     private def maybeSpillCollection(usingMap: Boolean): Unit = {
        var estimatedSize = 0L
        if (usingMap) {
          estimatedSize = map.estimateSize()
          if (maybeSpill(map, estimatedSize)) {
            map = new PartitionedAppendOnlyMap[K, C]
          }
        } else {
          estimatedSize = buffer.estimateSize()
          if (maybeSpill(buffer, estimatedSize)) {
            buffer = new PartitionedPairBuffer[K, C]
          }
        }
    
        if (estimatedSize > _peakMemoryUsedBytes) {
          _peakMemoryUsedBytes = estimatedSize
        }
      }
    

    通过集合的estimateSize方法估计map的大小,若需要spill则将集合中的数据spill到磁盘文件,并且为集合创建一个新的对象放数据。先看看估计大小的方法estimateSize:

     def estimateSize(): Long = {
        assert(samples.nonEmpty)
        val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
        (samples.last.size + extrapolatedDelta).toLong
      }
    

    以上次采样完更新的bytePerUpdate 作为最近平均每次跟新的大小,估计当前占用内存:(当前update次数-上次采样时的update次数) * 每次跟新大小 + 上次采样记录的大小。

    获取到当前集合的大小后调用maybeSpill判断是否需要spill:

    protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
        var shouldSpill = false
        if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
          // Claim up to double our current memory from the shuffle memory pool
          val amountToRequest = 2 * currentMemory - myMemoryThreshold
          val granted = acquireMemory(amountToRequest)
          // 跟新申请到的内存
          myMemoryThreshold += granted 
          // 集合大小还是比申请到的内存大?spill : no spill
          shouldSpill = currentMemory >= myMemoryThreshold
        }
        shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
        // Actually spill
        if (shouldSpill) {
          _spillCount += 1
          logSpillage(currentMemory)
          spill(collection)
          _elementsRead = 0
          _memoryBytesSpilled += currentMemory
          releaseMemory()
        }
        shouldSpill
      }
    

    这里有两种情况都可导致spill:

    • 当前集合包含的records数超过了 numElementsForceSpillThreshold(默认为Long.MaxValue,可通过spark.shuffle.spill.numElementsForceSpillThreshold设置)
    • 当前集合包含的records数为32的整数倍,并且当前集合的大小超过了申请的内存myMemoryThreshold(第一次申请默认为5 * 1024 * 1024,可通过spark.shuffle.spill.initialMemoryThreshold设置),此时并不会立即spill,会尝试申请更多的内存避免spill,这里尝试申请的内存为2倍集合大小减去当前已经申请的内存大小(实际申请到的内存为granted),若加上原来的内存还是比当前集合的大小要小则需要spill。

    若需要spill,则跟新spill次数,调用spill(collection)方法进行溢写磁盘,并释放内存。
    跟进spill方法看看其具体实现:

    override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
        // 传入comparator将集合中的数据先根据partition排序再通过key排序后返回一个迭代器
        val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
        // 写到磁盘文件,并返回一个对该文件的描述对象SpilledFile
        val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
        // 添加到spill文件数组
        spills.append(spillFile)
      }
    

    继续跟进看看spillMemoryIteratorToDisk的实现:

    private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
          : SpilledFile = {
        // 生成临时文件和blockId
        val (blockId, file) = diskBlockManager.createTempShuffleBlock()
    
        // 这些值在每次flush后会被重置
        var objectsWritten: Long = 0
        var spillMetrics: ShuffleWriteMetrics = null
        var writer: DiskBlockObjectWriter = null
        def openWriter(): Unit = {
          assert (writer == null && spillMetrics == null)
          spillMetrics = new ShuffleWriteMetrics
          writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
        }
        openWriter()
    
        // 按写入磁盘的顺序记录分支的大小
        val batchSizes = new ArrayBuffer[Long]
    
        // 记录每个分区有多少元素
        val elementsPerPartition = new Array[Long](numPartitions)
    
        // Flush  writer 内容到磁盘,并更新相关变量
        def flush(): Unit = {
          val w = writer
          writer = null
          w.commitAndClose()
          _diskBytesSpilled += spillMetrics.bytesWritten
          batchSizes.append(spillMetrics.bytesWritten)
          spillMetrics = null
          objectsWritten = 0
        }
    
        var success = false
        try {
          // 遍历迭代器
          while (inMemoryIterator.hasNext) {
            val partitionId = inMemoryIterator.nextPartition()
            require(partitionId >= 0 && partitionId < numPartitions,
              s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
            inMemoryIterator.writeNext(writer)
            elementsPerPartition(partitionId) += 1
            objectsWritten += 1
            // 元素个数达到批量序列化大小则flush到磁盘
            if (objectsWritten == serializerBatchSize) {
              flush()
              openWriter()
            }
          }
          // 将剩余的数据flush
          if (objectsWritten > 0) {
            flush()
          } else if (writer != null) {
            val w = writer
            writer = null
            w.revertPartialWritesAndClose()
          }
          success = true
        } finally {
            ...
        }
        // 返回SpilledFile
        SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
      }
    

    通过diskBlockManager创建临时文件和blockID,临时文件名格式为是 "temp_shuffle_" + id,遍历内存数据迭代器,并调用Writer(DiskBlockObjectWriter)的write方法,当写的次数达到序列化大小则flush到磁盘文件,并重新打开writer,及跟新batchSizes等信息。

    最后返回一个SpilledFile对象,该对象包含了溢写的临时文件File,blockId,每次flush的到磁盘的大小,每个partition对应的数据条数。

    spill完成,并且insertAll方法也执行完成,回到开始的SortShuffleWriter的write方法:

    override def write(records: Iterator[Product2[K, V]]): Unit = {
        ...
        // 写内存缓冲区,超过阈值则溢写到磁盘文件
        sorter.insertAll(records)
        // 获取该task的最终输出文件
        val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
        val tmp = Utils.tempFileWith(output)
        try {
          val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
          // merge后写到data文件
          val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
          // 写index文件shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
          mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
        } finally {
          if (tmp.exists() && !tmp.delete()) {
            logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
          }
        }
      }
    

    获取最后的输出文件名及blockId,文件格式:

     "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
    

    接着通过sorter.writePartitionedFile方法来写文件,其中包括内存及所有spill文件的merge操作,看看起具体实现:

    def writePartitionedFile(
          blockId: BlockId,
          outputFile: File): Array[Long] = {
    
        val writeMetrics = context.taskMetrics().shuffleWriteMetrics
    
        // 跟踪每个分区在文件中的range
        val lengths = new Array[Long](numPartitions)
        // 数据只存在内存中
        if (spills.isEmpty) { 
          val collection = if (aggregator.isDefined) map else buffer
          // 将内存中的数据先通过partitionId再通过k排序后返回一个迭代器
          val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
          // 遍历数据写入磁盘
          while (it.hasNext) {
            val writer = blockManager.getDiskWriter(
              blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
            val partitionId = it.nextPartition()
            //等待一个partition的数据写完后刷新到磁盘文件
            while (it.hasNext && it.nextPartition() == partitionId) {
              it.writeNext(writer)
            }
            writer.commitAndClose()
            val segment = writer.fileSegment()
            // 记录每个partition数据长度
            lengths(partitionId) = segment.length
          }
        } else {
          // 有数据spill到磁盘,先merge
          for ((id, elements) <- this.partitionedIterator) {
            if (elements.hasNext) {
              val writer = blockManager.getDiskWriter(
                blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
              for (elem <- elements) {
                writer.write(elem._1, elem._2)
              }
              writer.commitAndClose()
              val segment = writer.fileSegment()
              lengths(id) = segment.length
            }
          }
        }
    
        context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
    
        lengths
      }
    
    • 数据只存在内存中而没有spill文件,根据传入的比较函数comparator来对集合里的数据先根据partition排序再对里面的key排序并返回一个迭代器,遍历该迭代器得到所有recored,每一个partition对应一个writer,一个partition的数据写完后再flush到磁盘文件,并记录该partition的数据长度。
    • 数据有spill文件,通过方法partitionedIterator对内存和spill文件的数据进行merge-sort后返回一个(partitionId,对应分区的数据的迭代器)的迭代器,也是一个partition对应一个Writer,写完一个partition再flush到磁盘,并记录该partition数据的长度。

    接下来看看通过this.partitionedIterator方法是怎么将内存及spill文件的数据进行merge-sort的:

    def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
        val usingMap = aggregator.isDefined
        val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
        if (spills.isEmpty) {
          if (!ordering.isDefined) {
            // 只根据partitionId排序,不需要对key排序
            groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
          } else {
            // 需要对partitionID和key进行排序
            groupByPartition(destructiveIterator(
              collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
          }
        } else {
          // Merge spilled and in-memory data
          merge(spills, destructiveIterator(
            collection.partitionedDestructiveSortedIterator(comparator)))
        }
      }
    

    这里在有spill文件的情况下会执行下面的merge方法,传入的是spill文件数组和内存中的数据进过partitionId和key排序后的数据迭代器,看看merge:

    private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
          : Iterator[(Int, Iterator[Product2[K, C]])] = {
        // 每个文件对应一个Reader
        val readers = spills.map(new SpillReader(_)) 
        val inMemBuffered = inMemory.buffered
        (0 until numPartitions).iterator.map { p =>
          // 获取内存中当前partition对应的Iterator
          val inMemIterator = new IteratorForPartition(p, inMemBuffered)
          // 将spill文件对应的partition的数据与内存中对应partition数据合并
          val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
          if (aggregator.isDefined) {
            // 对key进行聚合并排序
            (p, mergeWithAggregation(
              iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
          } else if (ordering.isDefined) {
            // 排序
            (p, mergeSort(iterators, ordering.get))
          } else {
            (p, iterators.iterator.flatten)
          }
        }
      }
    

    merge方法将属于同一个reduce端的partition的内存数据和spill文件数据合并起来,再进行聚合排序(有需要的话),最后返回(reduce对应的partitionId,该分区数据迭代器)

    将数据merge-sort后写入最终的文件后,需要将每个partition的偏移量持久化到文件以供后续每个reduce根据偏移量获取自己的数据,写偏移量的逻辑很简单,就是根据前面得到的partition长度的数组将偏移量写到index文件中:

    def writeIndexFileAndCommit(
          shuffleId: Int,
          mapId: Int,
          lengths: Array[Long],
          dataTmp: File): Unit = {
        val indexFile = getIndexFile(shuffleId, mapId)
        val indexTmp = Utils.tempFileWith(indexFile)
        try {
          val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
          Utils.tryWithSafeFinally {
            // We take in lengths of each block, need to convert it to offsets.
            var offset = 0L
            out.writeLong(offset)
            for (length <- lengths) {
              offset += length
              out.writeLong(offset)
            }
          } 
        ......
        }
      }
    

    根据shuffleId和mapId获取index文件并创建一个写文件的文件流,按照reduce端partition对应的offset依次写到index文件中,如:
    0,
    length(partition1),
    length(partition1)+length(partition2),
    length(partition1)+length(partition2)+length(partition3)
    ...

    最后创建一个MapStatus实例返回,包含了reduce端每个partition对应的偏移量。

    该对象将返回到Driver端的DAGScheluer处理,被添加到对应stage的OutputLoc里,当该stage的所有task完成的时候会将这些结果注册到MapOutputTrackerMaster,以便下一个stage的task就可以通过它来获取shuffle的结果的元数据信息。

    至此Shuffle Write完成!

    Shuffle Read部分请看 Shuffle Read解析

    相关文章

      网友评论

      本文标题:[spark] Shuffle Write解析 (Sort Ba

      本文链接:https://www.haomeiwen.com/subject/gdpemxtx.html