Spark source code reading: shuffle read

Author: WJL3333 | Published 2018-08-18 17:23

    When the DAGScheduler splits a job into stages and finds that a shuffle is required, the results produced by the upstream RDD computation are written out to local disk (the details will be analyzed in a later article).

    Next, the shuffled output has to be processed further (for example, by a count).
    Downstream of the previous RDDs there is a ShuffledRDD that consumes the shuffle output
    (it actually belongs to a new stage).

    As before, this stage is split into Tasks that are sent to the Executors.

    The Tasks produced here are ResultTasks, and they are quite simple: after the task is deserialized, it runs ShuffledRDD.iterator -> ShuffledRDD.compute:

    override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
        val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
        SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
          .read()
          .asInstanceOf[Iterator[(K, C)]]
      }
    

    The ShuffleManager provides a Reader that reads the data written by the previous shuffle so it can be processed.
    In practice this Reader is a BlockStoreShuffleReader.
    What does this class do?

    • The task has to compute the result for its own partition, so it first needs to know where the data of the partitions it depends on is stored (MapOutputTracker).
    • Fetch the dependent data from those locations (BlockManager).
    • Run the aggregation logic if a combine is required.
    • Sort if a key ordering is defined (ExternalSorter).

    /**
     * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
     * requesting them from other nodes' block stores.
     */
    private[spark] class BlockStoreShuffleReader[K, C](
        handle: BaseShuffleHandle[K, _, C],
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        serializerManager: SerializerManager = SparkEnv.get.serializerManager,
        blockManager: BlockManager = SparkEnv.get.blockManager,
        mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
      extends ShuffleReader[K, C] with Logging {
    
    override def read(): Iterator[Product2[K, C]] = {
        val wrappedStreams = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
          serializerManager.wrapStream,
          ...
       )
    
        val serializerInstance = dep.serializer.newInstance()
    
        // Create a key/value iterator for each stream
        val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
           ...
          serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
           ...
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          if (dep.mapSideCombine) {
             ...
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
           ...      
           interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if there is a sort ordering defined.
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            // Create an ExternalSorter to sort the data.
            val sorter =
              new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
            sorter.insertAll(aggregatedIter)
            ...
            CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
          case None =>
            aggregatedIter
        }
      }
    

    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
    This call returns a value of type Seq[(BlockManagerId, Seq[(BlockId, Long)])],
    i.e. which Executor holds which of the needed blocks, and how large each block is.
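
    To make the shape of that result concrete, here is a minimal sketch of how it can be walked. It assumes we are inside BlockStoreShuffleReader, where mapOutputTracker, handle, startPartition and endPartition are in scope; the loop itself is only for illustration and is not part of Spark:

    // Illustration only: iterate the map-output locations returned by the MapOutputTracker.
    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] =
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)

    for ((address, blockInfos) <- blocksByAddress) {
      // address: the BlockManagerId of the executor that holds these map outputs
      // blockInfos: the shuffle blocks on that executor together with their sizes in bytes
      blockInfos.foreach { case (blockId, size) =>
        logDebug(s"block $blockId ($size bytes) lives on executor ${address.executorId}")
      }
    }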

    (Every Executor has a SparkEnv, and every SparkEnv contains a BlockManager. The Driver-side BlockManager is the master and the Executor-side BlockManagers are slaves; together, the BlockManagers in the cluster form a master-slave topology. More on this in a later article.)

    ShuffleBlockFetcherIterator

    This object is responsible for pulling all of the required data from remote nodes.
    Execution first reaches the initialize method, which does the following:

    • Register a task-completion callback so memory (ByteBuffers) can be cleaned up when the task finishes.
    • Split the requests, since some of the data may actually be stored locally.
    • Fetch the remote blocks.
    • Fetch the local blocks.

    private[this] def initialize(): Unit = {
        
        context.addTaskCompletionListener(_ => cleanup())
        ...
        val remoteRequests = splitLocalRemoteBlocks()
        ...
        fetchRequests ++= Utils.randomize(remoteRequests)
        ...
        fetchUpToMaxBytes()
        ...
        fetchLocalBlocks()
        ...
      }
    
    

    Here the block metadata is turned into a list of fetch requests. There is an optimization: to speed things up, the blocks held by a single remote executor are grouped into several FetchRequests, each at most maxBytesInFlight / 5 in size, so that data can be pulled from up to five nodes in parallel instead of blocking on the output of one node.

    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
        // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
        // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
        // nodes, rather than blocking on reading output from one node.
    
        val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
        logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
          + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
    
        // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
        // at most maxBytesInFlight in order to limit the amount of data in flight.
        val remoteRequests = new ArrayBuffer[FetchRequest]
    
        // Tracks total number of blocks (including zero sized blocks)
        var totalBlocks = 0
        for ((address, blockInfos) <- blocksByAddress) {
          totalBlocks += blockInfos.size
          if (address.executorId == blockManager.blockManagerId.executorId) {
            // Filter out zero-sized blocks
            localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
            numBlocksToFetch += localBlocks.size
          } else {
            val iterator = blockInfos.iterator
            var curRequestSize = 0L
            var curBlocks = new ArrayBuffer[(BlockId, Long)]
            while (iterator.hasNext) {
              val (blockId, size) = iterator.next()
              if (size > 0) {
                curBlocks += ((blockId, size))
                remoteBlocks += blockId
                numBlocksToFetch += 1
                curRequestSize += size
              } else if (size < 0) {
                throw new BlockException(blockId, "Negative block size " + size)
              }
              if (curRequestSize >= targetRequestSize ||
                  curBlocks.size >= maxBlocksInFlightPerAddress) {
                remoteRequests += new FetchRequest(address, curBlocks)
                curBlocks = new ArrayBuffer[(BlockId, Long)]
                curRequestSize = 0
              }
            }
            // Add in the final request
            if (curBlocks.nonEmpty) {
              remoteRequests += new FetchRequest(address, curBlocks)
            }
          }
        }
        remoteRequests
      }
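
    To make the grouping above concrete: with the default maxBytesInFlight of 48 MB, targetRequestSize is 48 / 5 ≈ 9.6 MB, so twenty 2 MB blocks from one executor end up in four requests of five blocks each. The following standalone sketch (simplified, with plain String block ids instead of Spark's BlockId, and no local/remote split) reproduces just that grouping logic:

    import scala.collection.mutable.ArrayBuffer

    // Simplified illustration of the request-splitting idea in splitLocalRemoteBlocks.
    object SplitExample {
      def splitIntoRequests(
          blocks: Seq[(String, Long)],          // (blockId, size in bytes)
          maxBytesInFlight: Long,
          maxBlocksInFlightPerAddress: Int): Seq[Seq[(String, Long)]] = {
        val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
        val requests = ArrayBuffer[Seq[(String, Long)]]()
        var curBlocks = ArrayBuffer[(String, Long)]()
        var curRequestSize = 0L
        for ((id, size) <- blocks if size > 0) {
          curBlocks += ((id, size))
          curRequestSize += size
          if (curRequestSize >= targetRequestSize ||
              curBlocks.size >= maxBlocksInFlightPerAddress) {
            requests += curBlocks.toSeq
            curBlocks = ArrayBuffer[(String, Long)]()
            curRequestSize = 0L
          }
        }
        if (curBlocks.nonEmpty) requests += curBlocks.toSeq
        requests.toSeq
      }

      def main(args: Array[String]): Unit = {
        val twoMB = 2L * 1024 * 1024
        val blocks = (1 to 20).map(i => (s"shuffle_0_${i}_0", twoMB))
        val requests = splitIntoRequests(blocks,
          maxBytesInFlight = 48L * 1024 * 1024,
          maxBlocksInFlightPerAddress = Int.MaxValue)
        println(requests.map(_.size))   // four requests with 5 blocks each
      }
    }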
    
    

    Once the fetch requests have been built, fetching starts right away. A rate-control mechanism is applied here, based on maxBytesInFlight, maxReqsInFlight and maxBlocksInFlightPerAddress: if a request can be sent it is sent immediately; otherwise it is placed into a deferred queue and sent the next time this method is called.

    
    private def fetchUpToMaxBytes(): Unit = {
        // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
        // immediately, defer the request until the next time it can be processed.
    
        // Process any outstanding deferred fetch requests if possible.
        if (deferredFetchRequests.nonEmpty) {
          for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
            while (isRemoteBlockFetchable(defReqQueue) &&
                !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
              val request = defReqQueue.dequeue()
              logDebug(s"Processing deferred fetch request for $remoteAddress with "
                + s"${request.blocks.length} blocks")
              send(remoteAddress, request)
              if (defReqQueue.isEmpty) {
                deferredFetchRequests -= remoteAddress
              }
            }
          }
        }
    
        // Process any regular fetch requests if possible.
        while (isRemoteBlockFetchable(fetchRequests)) {
          val request = fetchRequests.dequeue()
          val remoteAddress = request.address
          if (isRemoteAddressMaxedOut(remoteAddress, request)) {
            logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
            val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
            defReqQueue.enqueue(request)
            deferredFetchRequests(remoteAddress) = defReqQueue
          } else {
            send(remoteAddress, request)
          }
        }
    
        def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
          sendRequest(request)
          numBlocksInFlightPerAddress(remoteAddress) =
            numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
        }
    
        def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
          fetchReqQueue.nonEmpty &&
            (bytesInFlight == 0 ||
              (reqsInFlight + 1 <= maxReqsInFlight &&
                bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
        }
    
        // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
        // given remote address.
        def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
            maxBlocksInFlightPerAddress
        }
      }
    
    

    Sending the requests is delegated to the ShuffleClient, which uses NettyBlockTransferService to issue block-fetch requests to the remote BlockManagers.
    The whole process is asynchronous: a callback is attached when a request is sent, and when a fetch completes the callback puts the result into a LinkedBlockingQueue.

    After that, the local blocks are fetched, which is simply a matter of reading them from the local BlockManager.
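
    The hand-off between the network callback and the task thread is a plain producer/consumer pattern over that queue. A minimal sketch of the idea (with hypothetical stand-in types, not Spark's actual FetchResult classes) looks like this:

    import java.util.concurrent.LinkedBlockingQueue

    // Hypothetical stand-ins for Spark's fetch-result classes, just to show the pattern.
    sealed trait FetchResult
    case class SuccessFetchResult(blockId: String, data: Array[Byte]) extends FetchResult
    case class FailureFetchResult(blockId: String, cause: Throwable) extends FetchResult

    object FetchQueueExample {
      private val results = new LinkedBlockingQueue[FetchResult]()

      // Called by the network layer when an async fetch finishes;
      // it only enqueues the outcome and never blocks the fetching thread.
      def onBlockFetchSuccess(blockId: String, data: Array[Byte]): Unit =
        results.put(SuccessFetchResult(blockId, data))

      def onBlockFetchFailure(blockId: String, cause: Throwable): Unit =
        results.put(FailureFetchResult(blockId, cause))

      // Called from the task thread (the iterator's next()); blocks until a result arrives.
      def nextResult(): FetchResult = results.take()
    }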

    The class itself is an Iterator, however, and its next() method is called repeatedly from outside.
    next() is effectively blocking: if the results queue is empty, it waits until something arrives.
    Each call takes one fetch result off the queue, handles it accordingly, and wraps it into an InputStream.
    Every call to next() hands back a BlockId and the corresponding stream without the caller having to care whether the block was remote or local, because each call to next() also invokes fetchUpToMaxBytes(), which keeps the remote fetches going:

    override def next(): (BlockId, InputStream) = {
        if (!hasNext) {
          throw new NoSuchElementException
        }
        numBlocksProcessed += 1
    
        var result: FetchResult = null
        var input: InputStream = null
        while (result == null) {
               ...
          result = results.take()
               ...
          result match {
            case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
              if (address != blockManager.blockManagerId) {
                numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
               ...
              }
              bytesInFlight -= size
              if (isNetworkReqDone) {
                reqsInFlight -= 1
              }
    
              val in = try {
                buf.createInputStream()
              } catch {
                case e: IOException =>
                  buf.release()
                  throwFetchFailedException(blockId, address, e)
              }
    
              input = streamWrapper(blockId, in)
               ...
              }
          }
          fetchUpToMaxBytes()   // <-----------------
        }
    
        currentResult = result.asInstanceOf[SuccessFetchResult]
        (currentResult.blockId, new BufferReleasingInputStream(input, this))
      }
    

    Flow summary

    A ResultTask runs on an Executor and calls ShuffledRDD's iterator method. That method obtains a BlockStoreShuffleReader from the ShuffleManager; this Reader takes care of fetching the remote shuffle output files, and once the data is in hand it applies combine and sorting as needed, so that the rest of the computation can proceed.
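
    For reference, any wide transformation followed by an action exercises this whole path. A minimal job (illustrative only, using local mode) triggers a shuffle write in the first stage and the shuffle read described above in the second stage:

    import org.apache.spark.{SparkConf, SparkContext}

    object ShuffleReadDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]"))
        val counts = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))
          .map(word => (word, 1))
          .reduceByKey(_ + _)   // wide dependency -> new stage backed by a ShuffledRDD
          .collect()            // the final stage's tasks read via BlockStoreShuffleReader
        println(counts.mkString(", "))
        sc.stop()
      }
    }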
