Spark Source Code Walkthrough: The Shuffle Read Path

Author: 越过山丘xyz | Published 2019-02-15 16:36

    Shuffle Read

    When a Task is executed, its runTask() method is invoked to run the task; runTask() in turn (through RDD.iterator()) ends up in RDD.getOrCompute(), which performs the actual computation:

    private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
      val blockId = RDDBlockId(id, partition.index)
      var readCachedBlock = true
        
      SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
        readCachedBlock = false
        computeOrReadCheckpoint(partition, context)
      }) match {
        // ...
      }
    }
    

    computeOrReadCheckpoint() first checks whether this RDD has already been checkpointed and materialized; if not, it calls compute() to do the computation.
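
    For reference, this is roughly what that method looks like in the Spark 2.x source (a lightly paraphrased sketch; details may vary across versions):

    // Roughly RDD.computeOrReadCheckpoint(): if the RDD has been checkpointed and
    // materialized, read from the checkpointed parent; otherwise compute it here.
    private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
      if (isCheckpointedAndMaterialized) {
        firstParent[T].iterator(split, context)
      } else {
        compute(split, context)
      }
    }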

    [Figure: ShuffledRDD class relationships]

    [Figure: shuffle read flow overview]

    To dissect the shuffle read path, we start from ShuffledRDD.compute():

    override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
      val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
      // Obtain a Reader from the ShuffleManager
      SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
        // Call the Reader's read() method
        .read()
        // Return an iterator over the fetched records
        .asInstanceOf[Iterator[(K, C)]]
    }
    

    Spark 2.2.3 uses SortShuffleManager as the default shuffle manager. The implementation of SortShuffleManager.getReader():

    override def getReader[K, C](
        handle: ShuffleHandle,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext): ShuffleReader[K, C] = {
      // Instantiate a BlockStoreShuffleReader as the reader
      new BlockStoreShuffleReader(
        handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
    }
    

    After obtaining the BlockStoreShuffleReader from the ShuffleManager, its read() method is called:

    override def read(): Iterator[Product2[K, C]] = {
      // Instantiate the ShuffleBlockFetcherIterator
      val wrappedStreams = new ShuffleBlockFetcherIterator(
        context,
        // Pass in the block-transfer (RPC) client
        blockManager.shuffleClient,
        blockManager,
        // Metadata describing where this reduce task's input blocks live
        mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
        serializerManager.wrapStream,
        // Maximum amount of data to fetch from the map side at one time
        SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
        SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
        // Maximum number of blocks to fetch from one address at a time
        // configured via spark.reducer.maxBlocksInFlightPerAddress
        // defaults to Int.MaxValue
        SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
        SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
        SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
    
      val serializerInstance = dep.serializer.newInstance()
    
      // Turn the fetched streams into a key-value record iterator
      val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
        // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
        // NextIterator. The NextIterator makes sure that close() is called on the
        // underlying InputStream when all records have been read.
        serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
      }
    
      // ...
    
      // An interruptible iterator must be used here in order to support task cancellation
      val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
      // Aggregation
      // analyzed in detail below
      val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
        if (dep.mapSideCombine) {
          // We are reading values that are already combined
          // Already combined on the map side
          // create an iterator over the already-combined key-value pairs
          val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
          // Merge the combiners
          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
        } else {
          // We don't know the value type, but also don't care -- the dependency *should*
          // have made sure its compatible w/ this aggregator, which will convert the value
          // type to the combined type C
          // Not combined on the map side
          val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
        }
      } else {
        require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
        interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
      }
    
      // If key ordering is required, sort the aggregated data
      // and return a CompletionIterator
      dep.keyOrdering match {
        case Some(keyOrd: Ordering[K]) =>
          // Create an ExternalSorter to sort the data.
          val sorter =
            new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
          sorter.insertAll(aggregatedIter)
          context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
          context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
          context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
          CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
        case None =>
          aggregatedIter
      }
    }
    

    Since this is a lot of code, I will split it into three parts and discuss each in turn: the initialization of ShuffleBlockFetcherIterator, the creation of the aggregated-data iterator aggregatedIter, and the creation of the sorted-data iterator.

    1. Initialization of ShuffleBlockFetcherIterator

    In this part we look at two of the arguments passed when ShuffleBlockFetcherIterator is instantiated: blockManager.shuffleClient and mapOutputTracker.getMapSizesByExecutorId(...):

    // Instantiate the ShuffleBlockFetcherIterator
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      // Pass in the block-transfer (RPC) client
      blockManager.shuffleClient,
      blockManager,
      // Metadata describing where this reduce task's input blocks live
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      serializerManager.wrapStream,
      // Maximum amount of data to fetch from the map side at one time
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
      // Maximum number of blocks to fetch from one address at a time
      // configured via spark.reducer.maxBlocksInFlightPerAddress
      // defaults to Int.MaxValue
      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
      SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
    

    blockManager.shuffleClient:

    // externalShuffleServiceEnabled defaults to false
    // it can be changed by setting spark.shuffle.service.enabled in the conf
    private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
      val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
      new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
    } else {
      // An instance of NettyBlockTransferService
      // which uses Netty as the communication framework
      blockTransferService
    }
    

    By default, then, ShuffleBlockFetcherIterator communicates through Netty (via NettyBlockTransferService).
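
    If the external shuffle service is enabled, an ExternalShuffleClient is used instead. A minimal sketch of the configuration that flips this switch (the app name here is hypothetical):

    import org.apache.spark.SparkConf

    // spark.shuffle.service.enabled defaults to false; setting it to true makes the
    // BlockManager hand out an ExternalShuffleClient instead of NettyBlockTransferService.
    val conf = new SparkConf()
      .setAppName("shuffle-read-demo")
      .set("spark.shuffle.service.enabled", "true")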

    [Figure: MapOutputTracker class relationships]

    mapOutputTracker is an instance of MapOutputTrackerWorker (the executor-side tracker); the implementation of its getMapSizesByExecutorId(...) method:

    override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
        : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
      logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
      // Fetch the metadata for the blocks to be pulled
      val statuses = getStatuses(shuffleId)
      try {
        // Convert the metadata into Seq[(BlockManagerId, Seq[(BlockId, Long)])]
        MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
      } catch {
        // ..
      }
    }
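
    MapOutputTracker.convertMapStatuses() is not shown in full here. Roughly, it walks the MapStatus array and, for every map output and every requested reduce partition, emits a (ShuffleBlockId, size) pair grouped by the BlockManagerId that hosts it. A paraphrased sketch (mutable HashMap/ArrayBuffer imports and the handling of missing statuses are elided):

    // Paraphrased sketch of MapOutputTracker.convertMapStatuses()
    def convertMapStatuses(
        shuffleId: Int,
        startPartition: Int,
        endPartition: Int,
        statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
      val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
      for ((status, mapId) <- statuses.zipWithIndex; part <- startPartition until endPartition) {
        // One shuffle block per (map output, reduce partition) pair, grouped by the
        // BlockManager that wrote it, together with its estimated size
        splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
          ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
      }
      splitsByAddress.toSeq
    }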
    

    The implementation of getStatuses():

    private def getStatuses(shuffleId: Int): Array[MapStatus] = {
      // mapStatuses is a ConcurrentHashMap instance
      // try the local cache first
      val statuses = mapStatuses.get(shuffleId).orNull
      if (statuses == null) {
        // Not in the cache: the metadata has to be fetched from the driver
        var fetchedStatuses: Array[MapStatus] = null
        fetching.synchronized {
          // fetching is a HashSet instance
          // it holds the shuffleIds whose map output metadata is currently being fetched,
          // i.e. some other thread is already fetching the metadata for the same shuffleId
          while (fetching.contains(shuffleId)) {
            // Another thread is already fetching the metadata; wait for it
            try {
              fetching.wait()
            } catch {
              case e: InterruptedException =>
            }
          }
    
          // Try the cache again,
          // because this thread may have just been woken up by a finished fetch
          fetchedStatuses = mapStatuses.get(shuffleId).orNull
          if (fetchedStatuses == null) {
            // Still not cached: add the shuffleId to the fetching set
            // so that no other thread fetches it at the same time
            // which is exactly what fetching.synchronized is for
            fetching += shuffleId
          }
        }
    
        if (fetchedStatuses == null) {
          // Do the actual metadata fetch
          try {
            // Send a GetMapOutputStatuses request to the MapOutputTrackerMasterEndpoint
            // this is a synchronous (ask) request
            // handled by MapOutputTrackerMasterEndpoint.receiveAndReply()
            val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
            // Deserialize the response
            fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
            // Put it into the local cache
            mapStatuses.put(shuffleId, fetchedStatuses)
          } finally {
            fetching.synchronized {
              // Remove the shuffleId from fetching and wake up waiting threads
              fetching -= shuffleId
              fetching.notifyAll()
            }
          }
        }
    
        if (fetchedStatuses != null) {
          fetchedStatuses
        } else {
          throw new MetadataFetchFailedException(
            shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
        }
      } else {
        statuses
      }
    }
    

    Next, let's see what MapOutputTrackerMasterEndpoint does when it receives a GetMapOutputStatuses message:

    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
      // Handle the GetMapOutputStatuses message
      case GetMapOutputStatuses(shuffleId: Int) =>
        val hostPort = context.senderAddress.hostPort
        // tracker is an instance of MapOutputTrackerMaster
        val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
    
      case StopMapOutputTracker =>
        context.reply(true)
        stop()
    }
    

    tracker is an instance of MapOutputTrackerMaster; the implementation of its post() method:

    def post(message: GetMapOutputMessage): Unit = {
      // mapOutputRequests is a LinkedBlockingQueue
      // the message is simply put into the queue
      // and handled by a background thread
      mapOutputRequests.offer(message)
    }
    

    MessageLoop is responsible for processing the messages that MapOutputTrackerMaster puts into the queue:

    private class MessageLoop extends Runnable {
      override def run(): Unit = {
        try {
          while (true) {
            try {
              // Take a message from the queue
              val data = mapOutputRequests.take()
              if (data == PoisonPill) {
                mapOutputRequests.offer(PoisonPill)
                return
              }
              // Extract the basic request information
              val context = data.context
              val shuffleId = data.shuffleId
              val hostPort = context.senderAddress.hostPort
              // shuffleStatuses is a ConcurrentHashMap
              // look up the cached metadata for this shuffleId
              val shuffleStatus = shuffleStatuses.get(shuffleId).head
              context.reply(
                // Serialize the map statuses and send them back as the reply
                shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
            } catch {
              case NonFatal(e) => logError(e.getMessage, e)
            }
          }
        } catch {
          case ie: InterruptedException => // exit
        }
      }
    }
    

    At this point, ShuffleBlockFetcherIterator has both a communication framework and the metadata describing the data it needs to fetch. Next, let's look at its initialization:

    initialize()
    
    private[this] def initialize(): Unit = {
      // Add a task completion callback (called in both success case and failure case) to cleanup.
      context.addTaskCompletionListener(_ => cleanup())
    
      // Split local and remote blocks.
      // Split the blocks into local and remote ones
      val remoteRequests = splitLocalRemoteBlocks()
      // Randomly shuffle the remote fetch requests
      // to avoid hotspotting a single node
      fetchRequests ++= Utils.randomize(remoteRequests)
     
      // Remote fetches
      fetchUpToMaxBytes()
    
      val numFetches = remoteRequests.size - fetchRequests.size
        
      // Local fetches
      fetchLocalBlocks()
    
    }
    

    This again breaks down into three parts: splitLocalRemoteBlocks(), fetchUpToMaxBytes(), and fetchLocalBlocks():

    • The implementation of splitLocalRemoteBlocks():

      private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
        // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
        // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
        // nodes, rather than blocking on reading output from one node.
        // Cap each remote request at maxBytesInFlight / 5
        // maxBytesInFlight, mentioned above, is the maximum amount of data fetched at once
        // dividing by 5 increases parallelism rather than fetching from a single node
        // this way we can fetch from up to 5 nodes at the same time, a small chunk each
        val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      
        // The remote fetch requests to return
        val remoteRequests = new ArrayBuffer[FetchRequest]
      
        // Tracks total number of blocks (including zero sized blocks)
        var totalBlocks = 0
        for ((address, blockInfos) <- blocksByAddress) {
          totalBlocks += blockInfos.size
          if (address.executorId == blockManager.blockManagerId.executorId) {
            // Blocks on the local node
            // filter out zero-sized blocks
            localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
            numBlocksToFetch += localBlocks.size
          } else {
            val iterator = blockInfos.iterator
            var curRequestSize = 0L
            var curBlocks = new ArrayBuffer[(BlockId, Long)]
            while (iterator.hasNext) {
              val (blockId, size) = iterator.next()
              // Skip empty blocks
              if (size > 0) {
                curBlocks += ((blockId, size))
                remoteBlocks += blockId
                numBlocksToFetch += 1
                curRequestSize += size
              } else if (size < 0) {
                throw new BlockException(blockId, "Negative block size " + size)
              }
              if (curRequestSize >= targetRequestSize ||
                  curBlocks.size >= maxBlocksInFlightPerAddress) {
                // Enough bytes/blocks accumulated: emit a FetchRequest
                remoteRequests += new FetchRequest(address, curBlocks)
                curBlocks = new ArrayBuffer[(BlockId, Long)]
                curRequestSize = 0
              }
            }
            // Build one final FetchRequest from the remaining blocks
            if (curBlocks.nonEmpty) {
              remoteRequests += new FetchRequest(address, curBlocks)
            }
          }
        }
        remoteRequests
      }
      

      splitLocalRemoteBlocks() separates the blocks to fetch into remote and local ones, so that each kind can be handled on its own.

    • The implementation of fetchUpToMaxBytes():

      private def fetchUpToMaxBytes(): Unit = {
        // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
        // immediately, defer the request until the next time it can be processed.
      
        // First handle fetch requests that were previously deferred
        if (deferredFetchRequests.nonEmpty) {
          for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
            while (isRemoteBlockFetchable(defReqQueue) &&
                !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
              val request = defReqQueue.dequeue()
              // Send the fetch request
              send(remoteAddress, request)
              if (defReqQueue.isEmpty) {
                deferredFetchRequests -= remoteAddress
              }
            }
          }
        }
      
        // Process any regular fetch requests if possible.
        while (isRemoteBlockFetchable(fetchRequests)) {
          val request = fetchRequests.dequeue()
          val remoteAddress = request.address
          if (isRemoteAddressMaxedOut(remoteAddress, request)) {
            // Too many blocks already in flight for this address; defer the request
            val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
            defReqQueue.enqueue(request)
            deferredFetchRequests(remoteAddress) = defReqQueue
          } else {
            // Send the request directly
            send(remoteAddress, request)
          }
        }
      
        // Send a request and update the per-address in-flight block count
        def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
          sendRequest(request)
          numBlocksInFlightPerAddress(remoteAddress) =
            numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
        }
      
        def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
          fetchReqQueue.nonEmpty &&
            (bytesInFlight == 0 ||
              (reqsInFlight + 1 <= maxReqsInFlight &&
                bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
        }
      
        // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
        // given remote address.
        def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
            // mentioned above:
            // defaults to Int.MaxValue
            // and can be set via spark.reducer.maxBlocksInFlightPerAddress
            maxBlocksInFlightPerAddress
        }
      }
      

      The actual fetches are ultimately issued by sendRequest(); let's look at its implementation:

      private[this] def sendRequest(req: FetchRequest) {
          
        bytesInFlight += req.size
        reqsInFlight += 1
      
        // so we can look up the size of each blockID
        val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
        val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
        val blockIds = req.blocks.map(_._1.toString)
        val address = req.address
      
        // Create the listener that receives fetch results
        val blockFetchingListener = new BlockFetchingListener {
          // Fetch succeeded
          override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
            // Only add the buffer to results queue if the iterator is not zombie,
            // i.e. cleanup() has not been called yet.
            ShuffleBlockFetcherIterator.this.synchronized {
              if (!isZombie) {
                // Increment the ref count because we need to pass this to a different thread.
                // This needs to be released after use.
                buf.retain()
                remainingBlocks -= blockId
                // Put the successful result into the results queue
                results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                  remainingBlocks.isEmpty))
              }
            }
            logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
          }
          // Fetch failed
          override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
            // Put the failure into the results queue
            results.put(new FailureFetchResult(BlockId(blockId), address, e))
          }
        }
      
        // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
        // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
        // the data and write it to file directly.
        if (req.size > maxReqSizeShuffleToMem) {
          // The requested data is too large to keep in memory
          // (it exceeds maxReqSizeShuffleToMem)
          // so stream it straight to disk
          shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
            blockFetchingListener, this)
        } else {
          // Fetch into memory
          shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
            blockFetchingListener, null)
        }
      }
      

      As mentioned above, ShuffleBlockFetcherIterator communicates via NettyBlockTransferService by default, so let's look at how NettyBlockTransferService.fetchBlocks() works:

      override def fetchBlocks(
          host: String,
          port: Int,
          execId: String,
          blockIds: Array[String],
          listener: BlockFetchingListener,
          tempShuffleFileManager: TempShuffleFileManager): Unit = {
          
        try {
          // A block fetch starter that supports retries
          val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
            override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
              // Create the transport client
              val client = clientFactory.createClient(host, port)
              // OneForOneBlockFetcher does the actual work
              new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
                transportConf, tempShuffleFileManager).start()
            }
          }
      
          // Maximum number of IO retries
          val maxRetries = transportConf.maxIORetries()
          if (maxRetries > 0) {
            // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
            // a bug in this code. We should remove the if statement once we're sure of the stability.
            // Retries enabled: wrap the starter in a RetryingBlockFetcher
            new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
          } else {
            // No retries: call blockFetchStarter.createAndStart() directly
            blockFetchStarter.createAndStart(blockIds, listener)
          }
        } catch {
        // ...
        }
      }
      

      The implementation of OneForOneBlockFetcher.start() (shown here as decompiled Java):

      public void start() {
          if (this.blockIds.length == 0) {
              throw new IllegalArgumentException("Zero-sized blockIds array");
          } else {
              this.client.sendRpc(this.openMessage.toByteBuffer(), new RpcResponseCallback() {
                // Fetch request succeeded
                  public void onSuccess(ByteBuffer response) { // the response bytes
                      try {
                        // Handle describing the opened block streams
                          OneForOneBlockFetcher.this.streamHandle = (StreamHandle)Decoder.fromByteBuffer(response);
                          // Iterate over the chunks one by one
                          for(int i = 0; i < OneForOneBlockFetcher.this.streamHandle.numChunks; ++i) {
                              if (OneForOneBlockFetcher.this.tempShuffleFileManager != null) {
                                // Download the chunk straight to disk
                                  OneForOneBlockFetcher.this.client.stream(OneForOneStreamManager.genStreamChunkId(OneForOneBlockFetcher.this.streamHandle.streamId, i), OneForOneBlockFetcher.this.new DownloadCallback(i));
                              } else {
                                // Fetch the chunk into memory
                                  OneForOneBlockFetcher.this.client.fetchChunk(OneForOneBlockFetcher.this.streamHandle.streamId, i, OneForOneBlockFetcher.this.chunkCallback);
                              }
                          }
                      } catch (Exception var3) {
                          OneForOneBlockFetcher.logger.error("Failed while starting block fetches after success", var3);
                          OneForOneBlockFetcher.this.failRemainingBlocks(OneForOneBlockFetcher.this.blockIds, var3);
                      }
      
                  }
                
                // Fetch request failed
                  public void onFailure(Throwable e) {
                      OneForOneBlockFetcher.logger.error("Failed while starting block fetches", e);
                      OneForOneBlockFetcher.this.failRemainingBlocks(OneForOneBlockFetcher.this.blockIds, e);
                  }
              });
          }
      }
      

      That wraps up our brief look at the remote fetch path.

    • Compared with remote fetching, fetchLocalBlocks() is much simpler:

    private[this] def fetchLocalBlocks() {
      val iter = localBlocks.iterator
      // Iterate over the local blocks
      while (iter.hasNext) {
        val blockId = iter.next()
        try {
          // Read the block data from the local BlockManager
          val buf = blockManager.getBlockData(blockId)
          shuffleMetrics.incLocalBlocksFetched(1)
          shuffleMetrics.incLocalBytesRead(buf.size)
          // Retain the buffer (increment its reference count)
          buf.retain()
          // Put the result into the results queue
          results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
        } catch {
          // ..
        }
      }
    }
    

    That covers the initialization of ShuffleBlockFetcherIterator. To summarize briefly: using the communication framework and the fetch metadata passed in, it performs both remote and local block fetches and puts the resulting blocks into the results queue.
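
    The reader then consumes those results through the iterator interface. Below is a heavily simplified paraphrase (not the exact Spark code: stream wrapping, corruption checks, and metrics are omitted) of how next() hands a fetched block back and keeps the fetch pipeline full:

    // Simplified paraphrase of ShuffleBlockFetcherIterator.next()
    override def next(): (BlockId, InputStream) = {
      // Block until some fetch result is available
      results.take() match {
        case SuccessFetchResult(blockId, address, size, buf, _) =>
          bytesInFlight -= size          // this data is no longer in flight
          fetchUpToMaxBytes()            // immediately try to issue more requests
          (blockId, buf.createInputStream())
        case FailureFetchResult(blockId, address, e) =>
          throwFetchFailedException(blockId, address, e)
      }
    }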

    2. Building the aggregated-data iterator aggregatedIter

    aggregatedIter is the iterator over the aggregated data; in other words, this is where the reduce-side aggregation happens:

    // Aggregation
    // build the iterator over the aggregated data
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        // Already combined on the map side
        // create an iterator over the already-combined key-value pairs
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        // Merge the combiners
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        // Not combined on the map side
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }
    

    The implementation of combineCombinersByKey(), which handles data that was already combined on the map side:

    def combineCombinersByKey(
        iter: Iterator[_ <: Product2[K, C]],
        context: TaskContext): Iterator[(K, C)] = {
      val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
      combiners.insertAll(iter)
      updateMetrics(context, combiners)
      combiners.iterator
    }
    

    The implementation of combineValuesByKey(), which handles data that was not combined on the map side:

    def combineValuesByKey(
        iter: Iterator[_ <: Product2[K, V]],
        context: TaskContext): Iterator[(K, C)] = {
      // Only the functions passed to ExternalAppendOnlyMap differ
      val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
      combiners.insertAll(iter)
      updateMetrics(context, combiners)
      combiners.iterator
    }
    

    The basic flow is the same; the only difference is the set of functions passed when the ExternalAppendOnlyMap is created.
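
    These three functions come from the ShuffleDependency's Aggregator. As a concrete illustration (a hypothetical example, not taken from the code above), the Aggregator behind reduceByKey(_ + _) on an RDD[(String, Int)] looks roughly like this:

    // Illustrative sketch: for reduceByKey, V and C are the same type, so
    // createCombiner just takes the first value as the initial combiner.
    val agg = new Aggregator[String, Int, Int](
      createCombiner = (v: Int) => v,                  // first value seen for a key
      mergeValue = (c: Int, v: Int) => c + v,          // fold another value into the combiner
      mergeCombiners = (c1: Int, c2: Int) => c1 + c2   // merge two partial aggregates
    )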

    [Figure: ExternalAppendOnlyMap class structure]

    The implementation of ExternalAppendOnlyMap.insertAll():

    // Similar to ExternalSorter.insertAll() on the shuffle write side
    def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
      if (currentMap == null) {
        throw new IllegalStateException(
          "Cannot insert new elements into a map after calling iterator")
      }
      // An update function for the map that we reuse across entries to avoid allocating
      // a new closure each time
      var curEntry: Product2[K, V] = null
      // The update function, also seen in the shuffle write analysis
      val update: (Boolean, C) => C = (hadVal, oldVal) => {
        if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
      }
    
      while (entries.hasNext) {
        curEntry = entries.next()
        val estimatedSize = currentMap.estimateSize()
        if (estimatedSize > _peakMemoryUsedBytes) {
          _peakMemoryUsedBytes = estimatedSize
        }
        // Check whether the map needs to be spilled to disk
        if (maybeSpill(currentMap, estimatedSize)) {
          currentMap = new SizeTrackingAppendOnlyMap[K, C]
        }
        // Merge the value into the existing combiner, or create one for a new key
        currentMap.changeValue(curEntry._1, update)
        addElementsRead()
      }
    }
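
    maybeSpill() comes from the Spillable base class and decides when the in-memory map has to be spilled to disk. A paraphrased sketch of its logic (simplified; force-spill thresholds and logging are omitted):

    // Paraphrased sketch of Spillable.maybeSpill(): every 32 elements, if the estimated
    // size exceeds the current memory threshold, try to acquire more execution memory;
    // if not enough is granted, spill the collection and release its memory.
    protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
      var shouldSpill = false
      if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
        val amountToRequest = 2 * currentMemory - myMemoryThreshold
        val granted = acquireMemory(amountToRequest)
        myMemoryThreshold += granted
        shouldSpill = currentMemory >= myMemoryThreshold
      }
      if (shouldSpill) {
        spill(collection)      // write the in-memory data out to a spill file
        releaseMemory()        // give the acquired execution memory back
      }
      shouldSpill
    }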
    

    Next, the implementation of ExternalAppendOnlyMap.iterator:

    override def iterator: Iterator[(K, C)] = {
      if (currentMap == null) {
        throw new IllegalStateException(
          "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
      }
      // Create a different iterator depending on whether spill files were produced
      if (spilledMaps.isEmpty) {
        // No spill files were produced
        // all data is still in memory
        CompletionIterator[(K, C), Iterator[(K, C)]](
          destructiveIterator(currentMap.iterator), freeCurrentMap())
      } else {
        // Spill files exist, so create an ExternalIterator to merge them
        new ExternalIterator()
      }
    }
    

    Let's see what work is done when ExternalIterator is instantiated:

    // A queue that maintains a buffer for each stream we are currently merging
    // This queue maintains the invariant that it only contains non-empty buffers
    private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
    
    // Input streams are derived both from the in-memory map and spilled maps on disk
    // The in-memory map is sorted in place, while the spilled maps are already in sorted order
    // Sort the in-memory map by key
    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
      currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
    // Combine the in-memory iterator with the iterators over the spilled files
    private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
    // Read data from each stream and put it into mergeHeap
    inputStreams.foreach { it =>
      val kcPairs = new ArrayBuffer[(K, C)]
      readNextHashCode(it, kcPairs)
      if (kcPairs.length > 0) {
        mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
      }
    }
    

    With that, we have the iterator over the aggregated data.
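
    ExternalIterator.next() then performs a k-way merge over these buffered streams: it dequeues the buffer with the smallest key hash from mergeHeap, merges in matching keys from the other buffers that share that hash, refills the drained buffers, and puts them back. As a self-contained illustration of this merge pattern (not Spark's actual code, which orders by key hash and handles hash collisions), here is a simplified k-way merge over streams that are already sorted by key:

    import scala.collection.mutable

    // Standalone sketch: merge several iterators that are each sorted by key,
    // combining the values of equal keys with mergeCombiners.
    def kWayMerge[K: Ordering, C](
        streams: Seq[Iterator[(K, C)]],
        mergeCombiners: (C, C) => C): Iterator[(K, C)] = new Iterator[(K, C)] {
      private val ord = implicitly[Ordering[K]]
      // A min-heap of buffered streams, ordered by each stream's current head key
      private val heap = mutable.PriorityQueue.empty[BufferedIterator[(K, C)]](
        Ordering.by[BufferedIterator[(K, C)], K](_.head._1)(ord).reverse)
      streams.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))

      def hasNext: Boolean = heap.nonEmpty

      def next(): (K, C) = {
        val first = heap.dequeue()
        val (key, firstValue) = first.next()
        var combined = firstValue
        if (first.hasNext) heap.enqueue(first)
        // Pull every pair with the same key from the other streams and merge it in
        while (heap.nonEmpty && ord.equiv(heap.head.head._1, key)) {
          val buf = heap.dequeue()
          combined = mergeCombiners(combined, buf.next()._2)
          if (buf.hasNext) heap.enqueue(buf)
        }
        (key, combined)
      }
    }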

    3. Building the sorted-data iterator (if needed)

    The sorted iterator is produced by ExternalSorter, which was already dissected in the shuffle write analysis:

    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        // Insert the data into the ExternalSorter
        // analyzed in the shuffle write article
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        // The result is again an iterator (a CompletionIterator)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
    

    Here, let's look at the implementation of ExternalSorter.iterator:

    def iterator: Iterator[Product2[K, C]] = {
      isShuffleSort = false
      partitionedIterator.flatMap(pair => pair._2)
    }
    

    The implementation of partitionedIterator:

    def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
      val usingMap = aggregator.isDefined
      val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
      if (spills.isEmpty) {
        // No spill files
        if (!ordering.isDefined) {
          // Group by partition
          groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
        } else {
          // Group by partition, with keys sorted within each partition
          groupByPartition(destructiveIterator(
            collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
        }
      } else {
        // Merge the in-memory data with the spilled files
        merge(spills, destructiveIterator(
          collection.partitionedDestructiveSortedIterator(comparator)))
      }
    }
    

    We will stop here; for the rest of the path, refer to the shuffle write analysis.
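
    To tie the whole read path together, here is a small, hypothetical driver program: reduceByKey exercises the aggregation branch (an Aggregator with mapSideCombine = true, no key ordering), while sortByKey exercises the ExternalSorter branch (keyOrdering defined, no aggregator):

    import org.apache.spark.{SparkConf, SparkContext}

    object ShuffleReadDemo {
      def main(args: Array[String]): Unit = {
        // Hypothetical local setup, just enough to trigger the shuffle read path above
        val sc = new SparkContext(
          new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]"))

        val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3), ("c", 4)), numSlices = 2)

        // ShuffledRDD with an Aggregator and mapSideCombine = true:
        // the reduce side goes through combineCombinersByKey()
        val counts = pairs.reduceByKey(_ + _).collect()

        // ShuffledRDD with keyOrdering defined and no aggregator:
        // the reduce side sorts with an ExternalSorter
        val sorted = pairs.sortByKey().collect()

        println(counts.mkString(", "))
        println(sorted.mkString(", "))
        sc.stop()
      }
    }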
