[spark] Shuffle Read Explained (Sort Based)

Author: BIGUFO | Published 2017-11-15 00:04

    For the Shuffle Write side, see the companion article on Shuffle Write.

    This article covers the reduce (read) side of shuffle. The first RDD of the stage downstream of a shuffle is a ShuffledRDD; its compute method returns an iterator over the data that the upstream stage spilled to disk files during Shuffle Write:

     override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
        val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
        SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
          .read()
          .asInstanceOf[Iterator[(K, C)]]
      }
    

    It obtains the shuffleManager from SparkEnv (here a SortShuffleManager), asks the manager for a reader, and calls its read method to get the iterator.
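
    To see where this read path runs from a user's point of view, here is a minimal, self-contained job (a sketch assuming a local-mode Spark setup; the object name ShuffleReadDemo is just for illustration). reduceByKey introduces the shuffle boundary, so every task of the downstream stage goes through ShuffledRDD.compute -> SortShuffleManager.getReader -> BlockStoreShuffleReader.read:

    import org.apache.spark.{SparkConf, SparkContext}

    object ShuffleReadDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]"))
        // reduceByKey creates a ShuffledRDD: the map stage writes partitioned output,
        // and the reduce stage reads it back through the shuffle reader discussed here.
        val counts = sc.parallelize(Seq("a", "b", "a", "c"), numSlices = 2)
          .map(word => (word, 1))
          .reduceByKey(_ + _)
        counts.collect().foreach(println)
        sc.stop()
      }
    }

    Back in the source, SortShuffleManager.getReader simply constructs the reader: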

    override def getReader[K, C](
          handle: ShuffleHandle,
          startPartition: Int,
          endPartition: Int,
          context: TaskContext): ShuffleReader[K, C] = {
        new BlockStoreShuffleReader(
          handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
      }
    

    getReader instantiates a BlockStoreShuffleReader, passing in the partition range this task needs to fetch. Let's look at its read method:

     override def read(): Iterator[Product2[K, C]] = {
        val blockFetcherItr = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          // Metadata describing where the map output data for these partitions is stored
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
          // Maximum total size of remote fetches allowed in flight at once
          SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
          SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
    
        // Wrap the streams for compression and encryption
        val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
          serializerManager.wrapStream(blockId, inputStream)
        }
      
        val serializerInstance = dep.serializer.newInstance()
    
        // Create a key/value iterator for each wrapped stream
        val recordIter = wrappedStreams.flatMap { wrappedStream =>
           serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
        // Update the task's read metrics for every record read
        val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
        // Build a completion iterator that merges the read metrics when exhausted
        val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
          recordIter.map { record =>
            readMetrics.incRecordsRead(1)
            record
          },
          context.taskMetrics().mergeShuffleReadMetrics())
    
        // An interruptible iterator must be used here in order to support task cancellation
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          if (dep.mapSideCombine) {
            // Values were already combined on the map side
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // Combine values on the reduce side only
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if a key ordering was specified
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            val sorter =
              new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
            sorter.insertAll(aggregatedIter)
            context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
            context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
            context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
            CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
          case None =>
            aggregatedIter
        }
      }
    

    read() first instantiates a ShuffleBlockFetcherIterator. One of its arguments is:

    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
    

    This call fetches the metadata describing where the reduce-side input comes from. It returns a Seq[(BlockManagerId, Seq[(BlockId, Long)])], i.e. which blocks on which nodes hold the data, together with each block's size. Let's see how getMapSizesByExecutorId is implemented:

    def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
          : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
        logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
        // Fetch the map output statuses (the metadata)
        val statuses = getStatuses(shuffleId)
        // Convert the format and extract the metadata for the requested partitions
        statuses.synchronized {
          return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
        }
      }
    
    • Use the shuffleId to obtain all map output metadata for that shuffle
    • Convert the format and extract the metadata for the requested partitions

    Drilling into getStatuses:

    private def getStatuses(shuffleId: Int): Array[MapStatus] = {
        // Try the local mapStatuses cache first
        val statuses = mapStatuses.get(shuffleId).orNull
        if (statuses == null) {
          logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
          val startTime = System.currentTimeMillis
          var fetchedStatuses: Array[MapStatus] = null
          ......
          if (fetchedStatuses == null) {
            // We won the race to fetch the statuses; do so
            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
            // This try-finally prevents hangs due to timeouts:
            try {
              // Fetch the metadata from the remote tracker
              val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
              // Deserialize it
              fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
              logInfo("Got the output locations")
              // Cache it in mapStatuses
              mapStatuses.put(shuffleId, fetchedStatuses)
            } finally {
              fetching.synchronized {
                fetching -= shuffleId
                fetching.notifyAll()
              }
            }
          } 
         .....
          }
        } else {
          return statuses
        }
      }
    

    If the statuses are already in mapStatuses they are returned directly; otherwise a GetMapOutputStatuses message is sent to the MapOutputTrackerMaster to fetch them.

    Recall that each Executor corresponds to one CoarseGrainedExecutorBackend. When a CoarseGrainedExecutorBackend is constructed it creates a SparkEnv, and creating the SparkEnv creates a mapOutputTracker, so there is a one-to-one mapping between mapOutputTracker and Executor: every Executor has its own mapOutputTracker to maintain this metadata.

    mapStatuses is where that mapOutputTracker stores the metadata. The metadata of Shuffle Writes completed on this Executor is kept in its mapStatuses, and metadata fetched remotely for Shuffle Writes completed on other Executors is cached there as well.

    Executors use a MapOutputTrackerWorker while the driver uses a MapOutputTrackerMaster, both created when the SparkEnv is instantiated. The result of every shuffle task completed on an Executor is registered with the driver-side MapOutputTrackerMaster, so its mapStatuses holds all of the metadata. When a task on an Executor needs the output of a shuffle, it first looks in its own mapStatuses, and only on a miss does it contact the MapOutputTrackerMaster to fetch the metadata.
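
    A stripped-down, hypothetical sketch of this "local cache first, RPC on miss" pattern (not Spark's actual MapOutputTrackerWorker; fetchFromMaster stands in for the GetMapOutputStatuses round trip):

    import scala.collection.concurrent.TrieMap

    // Hypothetical sketch: answer from the local cache when possible, otherwise ask
    // the master and cache the reply for later lookups on this executor.
    class LocalStatusCache[S](fetchFromMaster: Int => Array[S]) {
      private val mapStatuses = new TrieMap[Int, Array[S]]()

      def getStatuses(shuffleId: Int): Array[S] =
        mapStatuses.getOrElseUpdate(shuffleId, fetchFromMaster(shuffleId))
    }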

    When the MapOutputTrackerMaster receives the message, it handles it as follows:

    case GetMapOutputStatuses(shuffleId: Int) =>
          val hostPort = context.senderAddress.hostPort
          logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
          val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
    

    This calls the tracker's post method:

     def post(message: GetMapOutputMessage): Unit = {
        mapOutputRequests.offer(message)
      }
    

    The message is appended to mapOutputRequests, a LinkedBlockingQueue. When the MapOutputTrackerMaster is initialized, it starts a dedicated thread pool to process these requests:

    private val threadpool: ThreadPoolExecutor = {
        val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
        val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
        for (i <- 0 until numThreads) {
          pool.execute(new MessageLoop)
        }
        pool
      }
    

    Here is how the run method of the worker class MessageLoop is defined:

    private class MessageLoop extends Runnable {
        override def run(): Unit = {
          try {
            while (true) {
              try {
                // Take one GetMapOutputMessage off the queue
                val data = mapOutputRequests.take()
                 if (data == PoisonPill) {
                  // Put PoisonPill back so that other MessageLoops can see it.
                  mapOutputRequests.offer(PoisonPill)
                  return
                }
                val context = data.context
                val shuffleId = data.shuffleId
                val hostPort = context.senderAddress.hostPort
                logDebug("Handling request to send map output locations for shuffle " + shuffleId +
                  " to " + hostPort)
                // Look up the serialized map output statuses for this shuffleId
                val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
                // Reply to the requester with the serialized statuses
                context.reply(mapOutputStatuses)
              } catch {
                case NonFatal(e) => logError(e.getMessage, e)
              }
            }
          } catch {
            case ie: InterruptedException => // exit
          }
        }
      }
    

    It looks up the serialized metadata for the shuffleId and replies with it. Let's look at getSerializedMapOutputStatuses in detail:

    def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
        var statuses: Array[MapStatus] = null
        var retBytes: Array[Byte] = null
        var epochGotten: Long = -1
    
        // Look up the serialized MapStatuses in the cache; on a miss fall back to mapStatuses
        def checkCachedStatuses(): Boolean = {
          epochLock.synchronized {
            if (epoch > cacheEpoch) {
              cachedSerializedStatuses.clear()
              clearCachedBroadcast()
              cacheEpoch = epoch
            }
            cachedSerializedStatuses.get(shuffleId) match {
              case Some(bytes) =>
                retBytes = bytes
                true
              case None =>
                logDebug("cached status not found for : " + shuffleId)
                statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
                epochGotten = epoch
                false
            }
          }
        }
    
        if (checkCachedStatuses()) return retBytes
        var shuffleIdLock = shuffleIdLocks.get(shuffleId)
        if (null == shuffleIdLock) {
          val newLock = new Object()
          // in general, this condition should be false - but good to be paranoid
          val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
          shuffleIdLock = if (null != prevLock) prevLock else newLock
        }
        // synchronize so we only serialize/broadcast it once since multiple threads call
        // in parallel
        shuffleIdLock.synchronized {
          if (checkCachedStatuses()) return retBytes
    
          // Serialize the statuses
          val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
            isLocal, minSizeForBroadcast)
          logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
          // Add them into the table only if the epoch hasn't changed while we were working
          epochLock.synchronized {
            if (epoch == epochGotten) {
              cachedSerializedStatuses(shuffleId) = bytes
              if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
            } else {
              logInfo("Epoch changed, not caching!")
              removeBroadcast(bcast)
            }
          }
          bytes
        }
      }
    

    The overall idea: first look for the serialized metadata (MapStatuses) in the cache and return it directly on a hit; otherwise take the statuses from mapStatuses, serialize them, cache and return them. The bytes go back to the requesting MapOutputTrackerWorker, which deserializes them and adds the metadata to the mapStatuses of the current Executor.

    Back in getMapSizesByExecutorId: once getStatuses has returned all metadata for the shuffleId, convertMapStatuses turns it into location information of the form Seq[(BlockManagerId, Seq[(BlockId, Long)])], which is then used to read the requested partitions:

    private def convertMapStatuses(
          shuffleId: Int,
          startPartition: Int,
          endPartition: Int,
          statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
        assert (statuses != null)
        // Group the block metadata for the requested partitions by location
        val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
        for ((status, mapId) <- statuses.zipWithIndex) {
          if (status == null) {
            val errorMessage = s"Missing an output location for shuffle $shuffleId"
            logError(errorMessage)
            throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
          } else {
            for (part <- startPartition until endPartition) {
              splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
                ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
            }
          }
        }
    
        splitsByAddress.toSeq
      }
    

    Here statuses: Array[MapStatus] is the metadata of all Shuffle Write output files from the upstream stage, ordered by the map-side partitionId. zipWithIndex pairs each element with its array index, and that index is exactly the map-side partitionId. A ShuffleBlockId is then built from the shuffleId, the map partitionId and the reduce partitionId (on the write side the reduceId in the ShuffleBlockId is always 0, because one ShuffleMapTask produces a single block; the real reduce partitionId filled in here is what is later used against the index file to find that partition's offsets). The estimated size of each block is recorded as well, because remote fetches are size-limited, and finally the location information is returned.
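
    For reference, a ShuffleBlockId is a small case class whose name string encodes exactly these three ids; run in a spark-shell, a quick example (values invented for illustration):

    import org.apache.spark.storage.ShuffleBlockId

    // shuffle 0, map-side partition 3, reduce-side partition 7
    val blockId = ShuffleBlockId(shuffleId = 0, mapId = 3, reduceId = 7)
    blockId.name   // "shuffle_0_3_7" -- the block this reducer will request from that map output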

    At this point mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) is done and has returned the block metadata for the requested partitions.

    When the ShuffleBlockFetcherIterator object is constructed, its initialize() method is called:

    private[this] def initialize(): Unit = {
        // Add a task completion callback (called in both success case and failure case) to cleanup.
        context.addTaskCompletionListener(_ => cleanup())
    
        // Split blocks into local and remote ones and build the remote FetchRequests
        val remoteRequests = splitLocalRemoteBlocks()
        // Add the remote requests to the fetchRequests queue in random order
        fetchRequests ++= Utils.randomize(remoteRequests)
        assert ((0 == reqsInFlight) == (0 == bytesInFlight),
          "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
          ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
    
        // Dequeue requests from fetchRequests and send them via sendRequest
        fetchUpToMaxBytes()
    
        val numFetches = remoteRequests.size - fetchRequests.size
        logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    
        // Fetch the local blocks
        fetchLocalBlocks()
        logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
      }
    
    • Split blocks into local and remote ones; build remote FetchRequests and add them to the fetchRequests queue
    • Dequeue requests from fetchRequests and send them via sendRequest to fetch the remote data
    • Fetch the local blocks

    First, how local and remote blocks are split:

    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
        // Divide the max in-flight size by 5 to increase parallelism: up to 5 requests to 5 nodes at once
        val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
        logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
    
        // Buffer for the remote fetch requests
        val remoteRequests = new ArrayBuffer[FetchRequest]
    
        // Tracks total number of blocks (including zero sized blocks)
        var totalBlocks = 0
        for ((address, blockInfos) <- blocksByAddress) {
          totalBlocks += blockInfos.size
          // Blocks on the current executor are local; everything else is remote
          if (address.executorId == blockManager.blockManagerId.executorId) {
            // Filter out zero-sized blocks
            localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
            numBlocksToFetch += localBlocks.size
          } else {
            val iterator = blockInfos.iterator
            var curRequestSize = 0L
            var curBlocks = new ArrayBuffer[(BlockId, Long)]
            while (iterator.hasNext) {
              val (blockId, size) = iterator.next()
              // Skip empty blocks
              if (size > 0) {
                curBlocks += ((blockId, size))
                remoteBlocks += blockId
                numBlocksToFetch += 1
                curRequestSize += size
              } else if (size < 0) {
                throw new BlockException(blockId, "Negative block size " + size)
              }
              // Once the accumulated size reaches the limit, create a FetchRequest and add it to remoteRequests
              if (curRequestSize >= targetRequestSize) {
                // Add this FetchRequest
                remoteRequests += new FetchRequest(address, curBlocks)
                curBlocks = new ArrayBuffer[(BlockId, Long)]
                logDebug(s"Creating fetch request of $curRequestSize at $address")
                curRequestSize = 0
              }
            }
            // Wrap any remaining blocks into one final FetchRequest
            if (curBlocks.nonEmpty) {
              remoteRequests += new FetchRequest(address, curBlocks)
            }
          }
        }
        logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
        remoteRequests
      }
    
    • To increase the parallelism of remote fetches, the per-request size limit is maxBytesInFlight / 5, so that up to 5 requests to up to 5 different nodes can be in flight at once (see the sketch below)
    • A block counts as local only if it lives on the same executor as the current task
    • The blocks of each remote node (executor) are walked and grouped: once the accumulated size reaches the limit, a FetchRequest is built and added to remoteRequests, and remoteRequests is returned at the end. A FetchRequest wraps one request: the address plus the blockIds and their sizes
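
    The packing rule can be isolated into a small hypothetical helper (plain Scala, not Spark code), using strings as stand-ins for BlockIds:

    import scala.collection.mutable.ArrayBuffer

    // Accumulate non-empty blocks until the running size reaches targetRequestSize,
    // then cut one request; any leftover blocks form a final request.
    def packIntoRequests(blocks: Seq[(String, Long)], targetRequestSize: Long): Seq[Seq[(String, Long)]] = {
      val requests = new ArrayBuffer[Seq[(String, Long)]]
      var current = new ArrayBuffer[(String, Long)]
      var currentSize = 0L
      for ((blockId, size) <- blocks if size > 0) {
        current += ((blockId, size))
        currentSize += size
        if (currentSize >= targetRequestSize) {
          requests += current.toSeq                     // cut a request at the size limit
          current = new ArrayBuffer[(String, Long)]
          currentSize = 0L
        }
      }
      if (current.nonEmpty) requests += current.toSeq   // remaining blocks
      requests.toSeq
    }

    // With the default spark.reducer.maxSizeInFlight of 48m, the target is 48 MB / 5 ≈ 9.6 MB, so e.g.
    // packIntoRequests(Seq(("b1", 6L << 20), ("b2", 5L << 20), ("b3", 2L << 20)), (48L << 20) / 5)
    // yields two requests: Seq(b1, b2) and Seq(b3).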

    After the split, the remote requests are added to the fetchRequests queue and fetchUpToMaxBytes() is called to fetch the remote data:

    private def fetchUpToMaxBytes(): Unit = {
        // Send fetch requests up to maxBytesInFlight
        while (fetchRequests.nonEmpty &&
          (bytesInFlight == 0 ||
            (reqsInFlight + 1 <= maxReqsInFlight &&
              bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
          sendRequest(fetchRequests.dequeue())
        }
      }
    

    It dequeues FetchRequests from fetchRequests and calls sendRequest on each:

     private[this] def sendRequest(req: FetchRequest) {
        logDebug("Sending request for %d blocks (%s) from %s".format(
          req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
        bytesInFlight += req.size
        reqsInFlight += 1
    
        // Build a Map[blockId -> size]
        val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
        val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
        val blockIds = req.blocks.map(_._1.toString)
    
        val address = req.address
        // Fetch the blocks on the remote node via shuffleClient.fetchBlocks
        shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
          new BlockFetchingListener {
            // Put each fetched block into the results queue
            override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
              // Only add the buffer to results queue if the iterator is not zombie,
              // i.e. cleanup() has not been called yet.
              ShuffleBlockFetcherIterator.this.synchronized {
                if (!isZombie) {
                  // Increment the ref count because we need to pass this to a different thread.
                  // This needs to be released after use.
                  buf.retain()
                  remainingBlocks -= blockId
                  results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                    remainingBlocks.isEmpty))
                  logDebug("remainingBlocks: " + remainingBlocks)
                }
              }
              logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
            }
    
            override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
              logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
              results.put(new FailureFetchResult(BlockId(blockId), address, e))
            }
          }
        )
      }
    

    The data on the remote node is fetched via shuffleClient.fetchBlocks, which by default is implemented by NettyBlockTransferService. Whether a block succeeds or fails, a SuccessFetchResult or FailureFetchResult is put into the results queue.
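
    The pattern here (asynchronous fetch callbacks feeding a blocking queue that the consumer drains) can be sketched in isolation with hypothetical types, not Spark's actual FetchResult classes:

    import java.util.concurrent.LinkedBlockingQueue

    sealed trait FetchOutcome
    case class Fetched(blockId: String, bytes: Array[Byte]) extends FetchOutcome
    case class FetchFailed(blockId: String, error: Throwable) extends FetchOutcome

    // Callbacks (running on the network client's threads) put outcomes into the queue;
    // the consumer blocks on take() until the next outcome, success or failure, arrives.
    class ResultQueue {
      private val results = new LinkedBlockingQueue[FetchOutcome]()
      def onSuccess(blockId: String, bytes: Array[Byte]): Unit = results.put(Fetched(blockId, bytes))
      def onFailure(blockId: String, e: Throwable): Unit = results.put(FetchFailed(blockId, e))
      def nextOutcome(): FetchOutcome = results.take()
    }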

    After the remote requests have been sent, fetchLocalBlocks() reads the local blocks:

    private[this] def fetchLocalBlocks() {
        val iter = localBlocks.iterator
        while (iter.hasNext) {
          val blockId = iter.next()
          try {
            val buf = blockManager.getBlockData(blockId)
            shuffleMetrics.incLocalBlocksFetched(1)
            shuffleMetrics.incLocalBytesRead(buf.size)
            buf.retain()
            results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
          } catch {
            case e: Exception =>
              // If we see an exception, stop immediately.
              logError(s"Error occurred while fetching local blocks", e)
              results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
              return
          }
        }
      }
    

    It iterates over the blocks to be fetched locally, reads each one directly from the blockManager, and puts a SuccessFetchResult or FailureFetchResult built from the result into results. Let's look at blockManager.getBlockData(blockId):

    override def getBlockData(blockId: BlockId): ManagedBuffer = {
        if (blockId.isShuffle) {
          shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
        } else {
          getLocalBytes(blockId) match {
            case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
            case None =>
              // If this block manager receives a request for a block that it doesn't have then it's
              // likely that the master has outdated block statuses for this block. Therefore, we send
              // an RPC so that this block is marked as being unavailable from this block manager.
              reportBlockStatus(blockId, BlockStatus.empty)
              throw new BlockNotFoundException(blockId.toString)
          }
        }
      }
    

    For shuffle blocks the call is delegated to the shuffle block resolver's getBlockData:

    override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
        // Locate the index file from the shuffleId and mapId
        val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
        val in = new DataInputStream(new FileInputStream(indexFile))
        try {
          // Skip to this block's (reduce partition's) entry in the index
          ByteStreams.skipFully(in, blockId.reduceId * 8)
          // Start offset of this partition in the data file
          val offset = in.readLong()
          // End offset (the start offset of the next partition)
          val nextOffset = in.readLong()
          new FileSegmentManagedBuffer(
            transportConf,
            getDataFile(blockId.shuffleId, blockId.mapId),
            offset,
            nextOffset - offset)
        } finally {
          in.close()
        }
      }
    

    It locates the index file from the shuffleId and mapId, opens a stream over it, skips ahead according to the block's reduceId (the one filled in earlier when building the ShuffleBlockId), reads the start and end offsets, and then serves exactly that byte range of the data file.
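
    Given that layout (the index file is a sequence of 8-byte longs where entry i is the start offset of reduce partition i in the data file), a minimal stand-alone sketch of the lookup could look like this (hypothetical helper, not Spark's resolver):

    import java.io.{DataInputStream, FileInputStream}

    // Partition reduceId occupies bytes [offsets(reduceId), offsets(reduceId + 1)) of the data file.
    def partitionRange(indexFilePath: String, reduceId: Int): (Long, Long) = {
      val in = new DataInputStream(new FileInputStream(indexFilePath))
      try {
        (0 until reduceId).foreach(_ => in.readLong())  // skip the first reduceId entries (8 bytes each)
        val offset = in.readLong()                      // start offset of this partition
        val nextOffset = in.readLong()                  // start offset of the next partition
        (offset, nextOffset)
      } finally {
        in.close()
      }
    }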

    With all fetch results now flowing into results, let's go back to the read() method:

     override def read(): Iterator[Product2[K, C]] = {
        val blockFetcherItr = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          // Ask the MapOutputTrackerMaster for the metadata describing where the data is stored
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
          // Maximum size of data in flight per fetch
          SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
          SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
    
        // Wrap the streams for compression and encryption
        val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
          serializerManager.wrapStream(blockId, inputStream)
        }
      
        val serializerInstance = dep.serializer.newInstance()
    
        // Create a key/value iterator for each wrapped stream
        val recordIter = wrappedStreams.flatMap { wrappedStream =>
           serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
        // Update the task's read metrics for every record read
        val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
        // Build a completion iterator that merges the read metrics when exhausted
        val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
          recordIter.map { record =>
            readMetrics.incRecordsRead(1)
            record
          },
          context.taskMetrics().mergeShuffleReadMetrics())
    
        // An interruptible iterator must be used here in order to support task cancellation
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          if (dep.mapSideCombine) {
            // Values were already combined on the map side
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // Combine values on the reduce side only
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if a key ordering was specified
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            val sorter =
              new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
            sorter.insertAll(aggregatedIter)
            context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
            context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
            context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
            CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
          case None =>
            aggregatedIter
        }
      }
    

    ShuffleBlockFetcherIterator itself extends Iterator, so the results can be iterated; its next() method hands back each FetchResult as a (blockId, inputStream) pair:

    case SuccessFetchResult(blockId, address, _, buf, _) =>
            try {
              (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
            } catch {
              case NonFatal(t) =>
                throwFetchFailedException(blockId, address, t)
            }
    

    The second half of read() handles aggregation and sorting, much like the Shuffle Write side, so only an outline is given here.

    When an aggregator is defined, combineCombinersByKey is used if map-side combine was performed, otherwise combineValuesByKey; both ultimately call ExternalAppendOnlyMap's insertAll(iter):

    def combineCombinersByKey(
          iter: Iterator[_ <: Product2[K, C]],
          context: TaskContext): Iterator[(K, C)] = {
        val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
        combiners.insertAll(iter)
        updateMetrics(context, combiners)
        combiners.iterator
      }
    
    def combineValuesByKey(
          iter: Iterator[_ <: Product2[K, V]],
          context: TaskContext): Iterator[(K, C)] = {
        val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
        combiners.insertAll(iter)
        updateMetrics(context, combiners)
        combiners.iterator
      }
    
    def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
        if (currentMap == null) {
          throw new IllegalStateException(
            "Cannot insert new elements into a map after calling iterator")
        }
        // An update function for the map that we reuse across entries to avoid allocating
        // a new closure each time
        var curEntry: Product2[K, V] = null
        val update: (Boolean, C) => C = (hadVal, oldVal) => {
          if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
        }
    
        while (entries.hasNext) {
          curEntry = entries.next()
          val estimatedSize = currentMap.estimateSize()
          if (estimatedSize > _peakMemoryUsedBytes) {
            _peakMemoryUsedBytes = estimatedSize
          }
          if (maybeSpill(currentMap, estimatedSize)) {
            currentMap = new SizeTrackingAppendOnlyMap[K, C]
          }
          currentMap.changeValue(curEntry._1, update)
          addElementsRead()
        }
      }
    

    The iteration inside insertAll ultimately pulls its data through the ShuffleBlockFetcherIterator.next method described above.

    On every update/insert the size of currentMap is estimated and checked against the spill threshold. If a spill is needed, the map's contents are sorted by key with the defined keyComparator, the resulting iterator is written to a temporary disk file, and a fresh map is created for the data that follows.
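
    A drastically simplified sketch of that insert/estimate/spill loop (hypothetical, not ExternalAppendOnlyMap: it uses the element count as a crude stand-in for estimateSize(), and keeps the "spills" in memory just to show the shape of the algorithm):

    import scala.collection.mutable

    class SpillingCombiner[K: Ordering, V](mergeValue: (V, V) => V, spillThreshold: Int) {
      private var currentMap = mutable.Map.empty[K, V]
      private val spills = mutable.ArrayBuffer.empty[Seq[(K, V)]]

      def insert(key: K, value: V): Unit = {
        // combine the new value into the in-memory map
        currentMap.update(key, currentMap.get(key).map(mergeValue(_, value)).getOrElse(value))
        if (currentMap.size >= spillThreshold) {       // stand-in for maybeSpill(estimateSize())
          spills += currentMap.toSeq.sortBy(_._1)      // "spill": sort by key and emit one sorted batch
          currentMap = mutable.Map.empty[K, V]         // fresh map for the data that follows
        }
      }

      def spilledBatches: Seq[Seq[(K, V)]] = spills.toSeq
      def inMemory: Seq[(K, V)] = currentMap.toSeq
    }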

    Once insertAll on combiners (an ExternalAppendOnlyMap) has finished, its iterator is called to obtain an iterator over the complete partition data, covering both the in-memory map and the spill files:

    override def iterator: Iterator[(K, C)] = {
        if (currentMap == null) {
          throw new IllegalStateException(
            "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
        }
        if (spilledMaps.isEmpty) {
          CompletionIterator[(K, C), Iterator[(K, C)]](
            destructiveIterator(currentMap.iterator), freeCurrentMap())
        } else {
          new ExternalIterator()
        }
      }
    

    Stepping into the construction of the ExternalIterator class:

    // A queue that maintains a buffer for each stream we are currently merging
        // This queue maintains the invariant that it only contains non-empty buffers
        private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
    
        // Input streams are derived both from the in-memory map and spilled maps on disk
        // The in-memory map is sorted in place, while the spilled maps are already in sorted order
        private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
          currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
        private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
    
        inputStreams.foreach { it =>
          val kcPairs = new ArrayBuffer[(K, C)]
          readNextHashCode(it, kcPairs)
          if (kcPairs.length > 0) {
            mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
          }
        }
    

    The sorted contents of currentMap are combined with the iterators over the spill files to form inputStreams. For each stream, its first batch of key/value pairs is buffered and enqueued into mergeHeap, which is then consumed by ExternalIterator's next() method.
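
    The merge itself can be shown with a small self-contained sketch (hypothetical and simplified: it merges by a key ordering, whereas the real ExternalIterator groups and combines entries by key hash code):

    import scala.collection.mutable

    // Every sorted input stream sits in a priority queue keyed by its current head;
    // next() pops the stream with the smallest head, emits one pair, and re-enqueues the stream.
    def mergeSorted[K, V](streams: Seq[Iterator[(K, V)]])(implicit ord: Ordering[K]): Iterator[(K, V)] = {
      val byHead = Ordering.by((it: BufferedIterator[(K, V)]) => it.head._1)(ord).reverse
      val mergeHeap = mutable.PriorityQueue.empty[BufferedIterator[(K, V)]](byHead)
      streams.map(_.buffered).filter(_.hasNext).foreach(mergeHeap.enqueue(_))

      new Iterator[(K, V)] {
        override def hasNext: Boolean = mergeHeap.nonEmpty
        override def next(): (K, V) = {
          val stream = mergeHeap.dequeue()   // stream with the smallest current key
          val kv = stream.next()
          if (stream.hasNext) mergeHeap.enqueue(stream)
          kv
        }
      }
    }

    // e.g. mergeSorted(Seq(Iterator(1 -> "a", 3 -> "c"), Iterator(2 -> "b"))).toList
    //      => List((1,a), (2,b), (3,c))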

    Finally, if the data needs a total ordering, it is sorted through the insertAll method of an ExternalSorter constructed with only the ordering parameter, exactly as on the Shuffle Write side, so it is not detailed again here.

    The final result is an iterator over all the data of the requested partition(s).
