
Spark Source Code: Running a Task

Author: Jorvi | Published 2019-12-17 15:39

    Source Code Series Index


    On the Driver side, the launchTasks method sends a LaunchTask message to the Executor to start a task.

    1 Receiving and Handling the LaunchTask Message

    • Enter org.apache.spark.executor.CoarseGrainedExecutorBackend.scala
      override def receive: PartialFunction[Any, Unit] = {
        case LaunchTask(data) =>
          if (executor == null) {
            exitExecutor(1, "Received LaunchTask command but executor was null")
          } else {
            val taskDesc = TaskDescription.decode(data.value)
            logInfo("Got assigned task " + taskDesc.taskId)
            executor.launchTask(this, taskDesc)
          }
      }
    
    1. Deserialize the TaskDescription;
    2. Call Executor.launchTask to submit the task.
    • Enter org.apache.spark.executor.Executor.scala
      def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
        val tr = new TaskRunner(context, taskDescription)
        runningTasks.put(taskDescription.taskId, tr)
        threadPool.execute(tr)
      }
    
    1. Create a TaskRunner from the TaskDescription;
    2. Put the TaskRunner into Executor.runningTasks;
    3. Execute the TaskRunner on the Executor's thread pool (a simplified sketch of this pattern follows).
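
    The pattern here is simply a concurrent registry plus a thread pool. The self-contained sketch below models it with hypothetical MiniTaskDescription / MiniTaskRunner / MiniExecutor stand-ins (not Spark classes): register the runner under its task id, hand it to the pool, and remove the entry when the run finishes, mirroring the finally block of the real TaskRunner.

      import java.util.concurrent.{ConcurrentHashMap, Executors}

      // Hypothetical, simplified stand-ins -- not Spark classes.
      final case class MiniTaskDescription(taskId: Long, payload: String)

      final class MiniTaskRunner(
          desc: MiniTaskDescription,
          running: ConcurrentHashMap[Long, MiniTaskRunner]) extends Runnable {
        override def run(): Unit = {
          try {
            // "Run" the task: here we just print the payload.
            println(s"Running task ${desc.taskId}: ${desc.payload}")
          } finally {
            // Mirror the Executor's finally-block bookkeeping: drop the entry when done.
            running.remove(desc.taskId)
          }
        }
      }

      object MiniExecutor {
        private val runningTasks = new ConcurrentHashMap[Long, MiniTaskRunner]()
        private val threadPool = Executors.newCachedThreadPool()

        def launchTask(desc: MiniTaskDescription): Unit = {
          val runner = new MiniTaskRunner(desc, runningTasks)
          runningTasks.put(desc.taskId, runner) // register first ...
          threadPool.execute(runner)            // ... then hand it to the pool
        }

        def main(args: Array[String]): Unit = {
          launchTask(MiniTaskDescription(0L, "hello"))
          launchTask(MiniTaskDescription(1L, "world"))
          threadPool.shutdown()
        }
      }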

    2 Running the Task

    • Enter the inner class TaskRunner in org.apache.spark.executor.Executor.scala
      class TaskRunner(
          execBackend: ExecutorBackend,
          private val taskDescription: TaskDescription)
        extends Runnable {
    
        override def run(): Unit = {
          threadId = Thread.currentThread.getId
          Thread.currentThread.setName(threadName)
          val threadMXBean = ManagementFactory.getThreadMXBean
          val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
          val deserializeStartTime = System.currentTimeMillis()
          val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
            threadMXBean.getCurrentThreadCpuTime
          } else 0L
          Thread.currentThread.setContextClassLoader(replClassLoader)
          val ser = env.closureSerializer.newInstance()
          logInfo(s"Running $taskName (TID $taskId)")
          execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
          var taskStartTime: Long = 0
          var taskStartCpu: Long = 0
          startGCTime = computeTotalGcTime()
    
          try {
            // Must be set before updateDependencies() is called, in case fetching dependencies
            // requires access to properties contained within (e.g. for access control).
            Executor.taskDeserializationProps.set(taskDescription.properties)
    
            updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
            task = ser.deserialize[Task[Any]](
              taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
            task.localProperties = taskDescription.properties
            task.setTaskMemoryManager(taskMemoryManager)
    
            // If this task has been killed before we deserialized it, let's quit now. Otherwise,
            // continue executing the task.
            val killReason = reasonIfKilled
            if (killReason.isDefined) {
              // Throw an exception rather than returning, because returning within a try{} block
              // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
              // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
              // for the task.
              throw new TaskKilledException(killReason.get)
            }
    
            // The purpose of updating the epoch here is to invalidate executor map output status cache
            // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
            // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
            // we don't need to make any special calls here.
            if (!isLocal) {
              logDebug("Task " + taskId + "'s epoch is " + task.epoch)
              env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
            }
    
            // Run the actual task and measure its runtime.
            taskStartTime = System.currentTimeMillis()
            taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
              threadMXBean.getCurrentThreadCpuTime
            } else 0L
            var threwException = true
            val value = Utils.tryWithSafeFinally {
              val res = task.run(
                taskAttemptId = taskId,
                attemptNumber = taskDescription.attemptNumber,
                metricsSystem = env.metricsSystem)
              threwException = false
              res
            } {
              val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
              val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
    
              if (freedMemory > 0 && !threwException) {
                val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
                if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
                  throw new SparkException(errMsg)
                } else {
                  logWarning(errMsg)
                }
              }
    
              if (releasedLocks.nonEmpty && !threwException) {
                val errMsg =
                  s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
                    releasedLocks.mkString("[", ", ", "]")
                if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
                  throw new SparkException(errMsg)
                } else {
                  logInfo(errMsg)
                }
              }
            }
            task.context.fetchFailed.foreach { fetchFailure =>
              // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
              // other exceptions.  Its *possible* this is what the user meant to do (though highly
              // unlikely).  So we will log an error and keep going.
              logError(s"TID ${taskId} completed successfully though internally it encountered " +
                s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
                s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
            }
            val taskFinish = System.currentTimeMillis()
            val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
              threadMXBean.getCurrentThreadCpuTime
            } else 0L
    
            // If the task has been killed, let's fail it.
            task.context.killTaskIfInterrupted()
    
            val resultSer = env.serializer.newInstance()
            val beforeSerialization = System.currentTimeMillis()
            val valueBytes = resultSer.serialize(value)
            val afterSerialization = System.currentTimeMillis()
    
            // Deserialization happens in two parts: first, we deserialize a Task object, which
            // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
            task.metrics.setExecutorDeserializeTime(
              (taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
            task.metrics.setExecutorDeserializeCpuTime(
              (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
            // We need to subtract Task.run()'s deserialization time to avoid double-counting
            task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
            task.metrics.setExecutorCpuTime(
              (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
            task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
            task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
    
            // Expose task metrics using the Dropwizard metrics system.
            // Update task metrics counters
            executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
            executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
            executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
            executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
            executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
            executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
            executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
              .inc(task.metrics.shuffleReadMetrics.fetchWaitTime)
            executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime)
            executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
              .inc(task.metrics.shuffleReadMetrics.totalBytesRead)
            executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
              .inc(task.metrics.shuffleReadMetrics.remoteBytesRead)
            executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
              .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk)
            executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
              .inc(task.metrics.shuffleReadMetrics.localBytesRead)
            executorSource.METRIC_SHUFFLE_RECORDS_READ
              .inc(task.metrics.shuffleReadMetrics.recordsRead)
            executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
              .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched)
            executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
              .inc(task.metrics.shuffleReadMetrics.localBlocksFetched)
            executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
              .inc(task.metrics.shuffleWriteMetrics.bytesWritten)
            executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
              .inc(task.metrics.shuffleWriteMetrics.recordsWritten)
            executorSource.METRIC_INPUT_BYTES_READ
              .inc(task.metrics.inputMetrics.bytesRead)
            executorSource.METRIC_INPUT_RECORDS_READ
              .inc(task.metrics.inputMetrics.recordsRead)
            executorSource.METRIC_OUTPUT_BYTES_WRITTEN
              .inc(task.metrics.outputMetrics.bytesWritten)
            executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
              .inc(task.metrics.outputMetrics.recordsWritten)
            executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
            executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
            executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
    
            // Note: accumulator updates must be collected after TaskMetrics is updated
            val accumUpdates = task.collectAccumulatorUpdates()
            // TODO: do not serialize value twice
            val directResult = new DirectTaskResult(valueBytes, accumUpdates)
            val serializedDirectResult = ser.serialize(directResult)
            val resultSize = serializedDirectResult.limit()
    
            // directSend = sending directly back to the driver
            val serializedResult: ByteBuffer = {
              if (maxResultSize > 0 && resultSize > maxResultSize) {
                logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
                  s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
                  s"dropping it.")
                ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
              } else if (resultSize > maxDirectResultSize) {
                val blockId = TaskResultBlockId(taskId)
                env.blockManager.putBytes(
                  blockId,
                  new ChunkedByteBuffer(serializedDirectResult.duplicate()),
                  StorageLevel.MEMORY_AND_DISK_SER)
                logInfo(
                  s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
                ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
              } else {
                logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
                serializedDirectResult
              }
            }
    
            setTaskFinishedAndClearInterruptStatus()
            execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
    
          } catch {
            case t: TaskKilledException =>
              logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
    
              val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
              val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums))
              execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
    
            case _: InterruptedException | NonFatal(_) if
                task != null && task.reasonIfKilled.isDefined =>
              val killReason = task.reasonIfKilled.getOrElse("unknown reason")
              logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
    
              val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
              val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums))
              execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
    
            case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
              val reason = task.context.fetchFailed.get.toTaskFailedReason
              if (!t.isInstanceOf[FetchFailedException]) {
                // there was a fetch failure in the task, but some user code wrapped that exception
                // and threw something else.  Regardless, we treat it as a fetch failure.
                val fetchFailedCls = classOf[FetchFailedException].getName
                logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
                  s"failed, but the ${fetchFailedCls} was hidden by another " +
                  s"exception.  Spark is handling this like a fetch failure and ignoring the " +
                  s"other exception: $t")
              }
              setTaskFinishedAndClearInterruptStatus()
              execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
    
            case CausedBy(cDE: CommitDeniedException) =>
              val reason = cDE.toTaskCommitDeniedReason
              setTaskFinishedAndClearInterruptStatus()
              execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
    
            case t: Throwable =>
              // Attempt to exit cleanly by informing the driver of our failure.
              // If anything goes wrong (or this was a fatal exception), we will delegate to
              // the default uncaught exception handler, which will terminate the Executor.
              logError(s"Exception in $taskName (TID $taskId)", t)
    
              // SPARK-20904: Do not report failure to driver if if happened during shut down. Because
              // libraries may set up shutdown hooks that race with running tasks during shutdown,
              // spurious failures may occur and can result in improper accounting in the driver (e.g.
              // the task failure would not be ignored if the shutdown happened because of premption,
              // instead of an app issue).
              if (!ShutdownHookManager.inShutdown()) {
                val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
    
                val serializedTaskEndReason = {
                  try {
                    ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
                  } catch {
                    case _: NotSerializableException =>
                      // t is not serializable so just send the stacktrace
                      ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
                  }
                }
                setTaskFinishedAndClearInterruptStatus()
                execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
              } else {
                logInfo("Not reporting error to driver during JVM shutdown.")
              }
    
              // Don't forcibly exit unless the exception was inherently fatal, to avoid
              // stopping other tasks unnecessarily.
              if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
                uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
              }
          } finally {
            runningTasks.remove(taskId)
          }
        }
      }
    
    1. Preparation: create a TaskMemoryManager, obtain a Serializer instance, report the task state as RUNNING, and so on;
    2. Set Executor.taskDeserializationProps to TaskDescription.properties;
    3. Call updateDependencies to fetch any missing or newly added dependencies (files and JARs);
    4. Deserialize TaskDescription.serializedTask into a Task;
    5. Set the Task's TaskMemoryManager;
    6. Check whether the task has already been killed; if so, throw a TaskKilledException and stop running it;
    7. Call task.run() to actually run the task and obtain its result value;
    8. Release all block locks held by the task;
    9. Clean up all memory allocated to the task;
    10. Serialize the result value;
    11. Record metrics in TaskMetrics: executor run time, GC time, result serialization time, and so on;
    12. Call task.collectAccumulatorUpdates() to collect the accumulator values used by the task (run time, GC time, record counts, etc.);
    13. Wrap the result value and accumulator updates in a DirectTaskResult and serialize it to serializedDirectResult;
    14. If resultSize exceeds maxResultSize (the spark.driver.maxResultSize setting), log a warning, drop serializedDirectResult, and serialize an IndirectTaskResult as the final result (serializedResult);
    15. Otherwise, if resultSize exceeds maxDirectResultSize (the smaller of spark.task.maxDirectResultSize and spark.rpc.message.maxSize), store serializedDirectResult in the BlockManager at the MEMORY_AND_DISK_SER level and serialize an IndirectTaskResult as the final result (serializedResult);
    16. Otherwise, use serializedDirectResult itself as the final result (serializedResult); a simplified sketch of this three-way decision follows;
    17. Call CoarseGrainedExecutorBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) to mark the task finished and send the final result back to the Driver;
    18. Remove the finished task's taskId from Executor.runningTasks.
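
    Steps 14 to 16 boil down to a three-way size check. The hedged sketch below models that decision with hypothetical DirectResult / IndirectResult stand-ins; maxResultSize and maxDirectResultSize are passed in as plain parameters rather than read from a SparkConf, and saveToBlockManager stands in for the real BlockManager call.

      // Simplified stand-ins for DirectTaskResult / IndirectTaskResult -- not Spark classes.
      sealed trait MiniTaskResult
      final case class DirectResult(bytes: Array[Byte]) extends MiniTaskResult
      final case class IndirectResult(blockId: String, size: Long) extends MiniTaskResult

      def chooseResult(
          taskId: Long,
          serialized: Array[Byte],
          maxResultSize: Long,        // analogue of spark.driver.maxResultSize
          maxDirectResultSize: Long,  // analogue of min(spark.task.maxDirectResultSize, spark.rpc.message.maxSize)
          saveToBlockManager: (String, Array[Byte]) => Unit): MiniTaskResult = {
        val resultSize = serialized.length.toLong
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          // Too big for the driver to accept at all: drop the payload, report only its size.
          IndirectResult(s"taskresult_$taskId", resultSize)
        } else if (resultSize > maxDirectResultSize) {
          // Too big for an RPC message: stash the bytes and send back a reference.
          val blockId = s"taskresult_$taskId"
          saveToBlockManager(blockId, serialized)
          IndirectResult(blockId, resultSize)
        } else {
          // Small enough: ship the serialized bytes straight back to the driver.
          DirectResult(serialized)
        }
      }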

    3 The Task Computation

    Task.run is called to actually run the task.

    • Enter org.apache.spark.scheduler.Task.scala
      final def run(
          taskAttemptId: Long,
          attemptNumber: Int,
          metricsSystem: MetricsSystem): T = {
        SparkEnv.get.blockManager.registerTask(taskAttemptId)
        // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
        // the stage is barrier.
        val taskContext = new TaskContextImpl(
          stageId,
          stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
          partitionId,
          taskAttemptId,
          attemptNumber,
          taskMemoryManager,
          localProperties,
          metricsSystem,
          metrics)
    
        context = if (isBarrier) {
          new BarrierTaskContext(taskContext)
        } else {
          taskContext
        }
    
        InputFileBlockHolder.initialize()
        TaskContext.setTaskContext(context)
        taskThread = Thread.currentThread()
    
        if (_reasonIfKilled != null) {
          kill(interruptThread = false, _reasonIfKilled)
        }
    
        new CallerContext(
          "TASK",
          SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
          appId,
          appAttemptId,
          jobId,
          Option(stageId),
          Option(stageAttemptId),
          Option(taskAttemptId),
          Option(attemptNumber)).setCurrentContext()
    
        try {
          runTask(context)
        } catch {
          case e: Throwable =>
            // Catch all errors; run task failure callbacks, and rethrow the exception.
            try {
              context.markTaskFailed(e)
            } catch {
              case t: Throwable =>
                e.addSuppressed(t)
            }
            context.markTaskCompleted(Some(e))
            throw e
        } finally {
          try {
            // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
            // one is no-op.
            context.markTaskCompleted(None)
          } finally {
            try {
              Utils.tryLogNonFatalError {
                // Release memory used by this thread for unrolling blocks
                SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
                SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
                  MemoryMode.OFF_HEAP)
                // Notify any tasks waiting for execution memory to be freed to wake up and try to
                // acquire memory again. This makes impossible the scenario where a task sleeps forever
                // because there are no other tasks left to notify it. Since this is safe to do but may
                // not be strictly necessary, we should revisit whether we can remove this in the
                // future.
                val memoryManager = SparkEnv.get.memoryManager
                memoryManager.synchronized { memoryManager.notifyAll() }
              }
            } finally {
              // Though we unset the ThreadLocal here, the context member variable itself is still
              // queried directly in the TaskRunner to check for FetchFailedExceptions.
              TaskContext.unset()
              InputFileBlockHolder.unset()
            }
          }
        }
      }
    
    1. Create a TaskContextImpl (wrapped in a BarrierTaskContext for barrier stages);
    2. Call runTask.

    Task is an abstract class; its concrete implementations are ShuffleMapTask and ResultTask. The sketch below illustrates the template-method relationship between run() and runTask().
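
    The following is a minimal, self-contained model of that hierarchy, using hypothetical Mini* names rather than the real Spark classes: the final run() performs the shared setup and cleanup, and each subclass supplies only runTask().

      abstract class MiniTask[T] {
        final def run(taskAttemptId: Long): T = {
          println(s"setting up task context for attempt $taskAttemptId")
          try {
            runTask()   // the subclass-specific work
          } finally {
            println("releasing per-task resources")
          }
        }
        protected def runTask(): T
      }

      // ShuffleMapTask-like: writes shuffle output and returns a status describing where it is.
      class MiniShuffleMapTask extends MiniTask[String] {
        protected def runTask(): String = "MapStatus(location of shuffle output)"
      }

      // ResultTask-like: applies the job's function to this partition's iterator.
      class MiniResultTask(func: Iterator[Int] => Int, partition: Seq[Int]) extends MiniTask[Int] {
        protected def runTask(): Int = func(partition.iterator)
      }

      object MiniTaskDemo extends App {
        println(new MiniShuffleMapTask().run(0L))                // MapStatus(...)
        println(new MiniResultTask(_.sum, Seq(1, 2, 3)).run(1L)) // 6
      }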

    • Enter org.apache.spark.scheduler.ShuffleMapTask.scala
      override def runTask(context: TaskContext): MapStatus = {
        // Deserialize the RDD using the broadcast variable.
        val threadMXBean = ManagementFactory.getThreadMXBean
        val deserializeStartTime = System.currentTimeMillis()
        val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        val ser = SparkEnv.get.closureSerializer.newInstance()
        val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
        _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
        } else 0L
    
        var writer: ShuffleWriter[Any, Any] = null
        try {
          val manager = SparkEnv.get.shuffleManager
          writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
          writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
          writer.stop(success = true).get
        } catch {
          case e: Exception =>
            try {
              if (writer != null) {
                writer.stop(success = false)
              }
            } catch {
              case e: Exception =>
                log.debug("Could not stop writer", e)
            }
            throw e
        }
      }
    
    1. Instantiate a closure Serializer;
    2. Deserialize taskBinary into (RDD, ShuffleDependency);
    3. Get the ShuffleManager and obtain a ShuffleWriter from it;
    4. Call ShuffleWriter.write on the partition's iterator, then stop the writer (a simplified sketch of this write-and-stop pattern follows).
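
    The essential shape of runTask here is: create a writer, feed it the records, and always stop it, reporting success only if write() completed. A hedged, self-contained sketch follows; MiniShuffleWriter is a hypothetical trait, not the Spark ShuffleWriter API.

      trait MiniShuffleWriter[K, V] {
        def write(records: Iterator[(K, V)]): Unit
        def stop(success: Boolean): Option[String]  // a map-status-like summary on success
      }

      def runShuffleWrite[K, V](
          records: Iterator[(K, V)],
          newWriter: () => MiniShuffleWriter[K, V]): String = {
        var writer: MiniShuffleWriter[K, V] = null
        try {
          writer = newWriter()
          writer.write(records)
          writer.stop(success = true).get
        } catch {
          case e: Exception =>
            // Best-effort cleanup; the original exception is what gets rethrown.
            if (writer != null) {
              try writer.stop(success = false) catch { case _: Exception => () }
            }
            throw e
        }
      }
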
    • Enter org.apache.spark.scheduler.ResultTask.scala
      override def runTask(context: TaskContext): U = {
        // Deserialize the RDD and the func using the broadcast variables.
        val threadMXBean = ManagementFactory.getThreadMXBean
        val deserializeStartTime = System.currentTimeMillis()
        val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        val ser = SparkEnv.get.closureSerializer.newInstance()
        val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
        _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
        } else 0L
    
        func(context, rdd.iterator(partition, context))
      }
    
    1. Instantiate a closure Serializer;
    2. Deserialize taskBinary into (RDD, func);
    3. Apply func to the partition's iterator and return its result (see the illustration below).
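
    To make the (rdd, func) pair concrete: for an action such as rdd.collect(), the job function shipped to each ResultTask is essentially iter => iter.toArray, so each task returns its partition's contents to the driver. The tiny stand-alone illustration below is a hypothetical simplification with no TaskContext and no real RDD.

      object MiniResultTaskDemo {
        def main(args: Array[String]): Unit = {
          val partition: Iterator[Int] = Iterator(1, 2, 3)   // what rdd.iterator(partition, context) yields
          val func: Iterator[Int] => Array[Int] = _.toArray  // the "func" a collect() would ship
          val resultForThisPartition = func(partition)       // what runTask returns for this partition
          println(resultForThisPartition.mkString(", "))     // 1, 2, 3
        }
      }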

    4 Returning the Result and State to the Driver

    • Enter org.apache.spark.executor.CoarseGrainedExecutorBackend.scala
      override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
        val msg = StatusUpdate(executorId, taskId, state, data)
        driver match {
          case Some(driverRef) => driverRef.send(msg)
          case None => logWarning(s"Drop $msg because has not yet connected to driver")
        }
      }
    
    1. Wrap the serialized result, the TaskState, and related fields into a StatusUpdate message;
    2. Send the StatusUpdate to the Driver (or log a warning and drop it if the driver reference is not yet connected).
    • Enter the inner class DriverEndpoint in org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.scala
        override def receive: PartialFunction[Any, Unit] = {
          case StatusUpdate(executorId, taskId, state, data) =>
            scheduler.statusUpdate(taskId, state, data.value)
            if (TaskState.isFinished(state)) {
              executorDataMap.get(executorId) match {
                case Some(executorInfo) =>
                  executorInfo.freeCores += scheduler.CPUS_PER_TASK
                  makeOffers(executorId)
                case None =>
                  // Ignoring the update since we don't know about the executor.
                  logWarning(s"Ignored task status update ($taskId state $state) " +
                    s"from unknown executor with ID $executorId")
              }
            }
        }
    
    1. Call TaskSchedulerImpl.statusUpdate;
    2. If the task is in a terminal state (TaskState.isFinished), look up the executorData for executorId in CoarseGrainedSchedulerBackend.executorDataMap, add the freed cores back to its freeCores, and call makeOffers(executorId) to offer resources on that executor again (a simplified sketch of this bookkeeping follows).
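
    A hedged, self-contained model of step 2, with hypothetical MiniExecutorData and scheduleMore standing in for ExecutorData and makeOffers:

      import scala.collection.mutable

      final case class MiniExecutorData(var freeCores: Int)

      object MiniDriverEndpoint {
        val CPUS_PER_TASK = 1
        val executorDataMap = mutable.Map("exec-1" -> MiniExecutorData(freeCores = 0))

        def onTaskFinished(executorId: String, scheduleMore: String => Unit): Unit =
          executorDataMap.get(executorId) match {
            case Some(data) =>
              data.freeCores += CPUS_PER_TASK  // the finished task gives its cores back
              scheduleMore(executorId)         // analogue of makeOffers(executorId)
            case None =>
              println(s"Ignored status update from unknown executor $executorId")
          }

        def main(args: Array[String]): Unit =
          onTaskFinished("exec-1", id => println(s"re-offering resources on $id"))
      }
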
    • Enter org.apache.spark.scheduler.TaskSchedulerImpl.scala
      def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
        var failedExecutor: Option[String] = None
        var reason: Option[ExecutorLossReason] = None
        synchronized {
          try {
            Option(taskIdToTaskSetManager.get(tid)) match {
              case Some(taskSet) =>
                if (state == TaskState.LOST) {
                  // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode,
                  // where each executor corresponds to a single task, so mark the executor as failed.
                  val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException(
                    "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)"))
                  if (executorIdToRunningTaskIds.contains(execId)) {
                    reason = Some(
                      SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
                    removeExecutor(execId, reason.get)
                    failedExecutor = Some(execId)
                  }
                }
                if (TaskState.isFinished(state)) {
                  cleanupTaskState(tid)
                  taskSet.removeRunningTask(tid)
                  if (state == TaskState.FINISHED) {
                    taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
                  } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
                    taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
                  }
                }
              case None =>
                logError(
                  ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
                    "likely the result of receiving duplicate task finished status updates) or its " +
                    "executor has been marked as failed.")
                    .format(state, tid))
            }
          } catch {
            case e: Exception => logError("Exception in statusUpdate", e)
          }
        }
        // Update the DAGScheduler without holding a lock on this, since that can deadlock
        if (failedExecutor.isDefined) {
          assert(reason.isDefined)
          dagScheduler.executorLost(failedExecutor.get, reason.get)
          backend.reviveOffers()
        }
      }
    
    1. If the TaskState is FINISHED, call TaskResultGetter.enqueueSuccessfulTask to fetch the result;
    2. If the TaskState is FAILED, KILLED, or LOST, call TaskResultGetter.enqueueFailedTask to handle the failure.
    • Enter org.apache.spark.scheduler.TaskResultGetter.scala
      def enqueueSuccessfulTask(
          taskSetManager: TaskSetManager,
          tid: Long,
          serializedData: ByteBuffer): Unit = {
        getTaskResultExecutor.execute(new Runnable {
          override def run(): Unit = Utils.logUncaughtExceptions {
            try {
              val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
                case directResult: DirectTaskResult[_] =>
                  if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
                    return
                  }
                  // deserialize "value" without holding any lock so that it won't block other threads.
                  // We should call it here, so that when it's called again in
                  // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
                  directResult.value(taskResultSerializer.get())
                  (directResult, serializedData.limit())
                case IndirectTaskResult(blockId, size) =>
                  if (!taskSetManager.canFetchMoreResults(size)) {
                    // dropped by executor if size is larger than maxResultSize
                    sparkEnv.blockManager.master.removeBlock(blockId)
                    return
                  }
                  logDebug("Fetching indirect task result for TID %s".format(tid))
                  scheduler.handleTaskGettingResult(taskSetManager, tid)
                  val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
                  if (!serializedTaskResult.isDefined) {
                    /* We won't be able to get the task result if the machine that ran the task failed
                     * between when the task ended and when we tried to fetch the result, or if the
                     * block manager had to flush the result. */
                    scheduler.handleFailedTask(
                      taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
                    return
                  }
                  val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
                    serializedTaskResult.get.toByteBuffer)
                  // force deserialization of referenced value
                  deserializedResult.value(taskResultSerializer.get())
                  sparkEnv.blockManager.master.removeBlock(blockId)
                  (deserializedResult, size)
              }
    
              // Set the task result size in the accumulator updates received from the executors.
              // We need to do this here on the driver because if we did this on the executors then
              // we would have to serialize the result again after updating the size.
              result.accumUpdates = result.accumUpdates.map { a =>
                if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
                  val acc = a.asInstanceOf[LongAccumulator]
                  assert(acc.sum == 0L, "task result size should not have been set on the executors")
                  acc.setValue(size.toLong)
                  acc
                } else {
                  a
                }
              }
    
              scheduler.handleSuccessfulTask(taskSetManager, tid, result)
            } catch {
              case cnf: ClassNotFoundException =>
                val loader = Thread.currentThread.getContextClassLoader
                taskSetManager.abort("ClassNotFound with classloader: " + loader)
              // Matching NonFatal so we don't catch the ControlThrowable from the "return" above.
              case NonFatal(ex) =>
                logError("Exception while getting task result", ex)
                taskSetManager.abort("Exception while getting task result: %s".format(ex))
            }
          }
        })
      }
    
    1. Submit a Runnable to the task-result-getter thread pool;
    2. On that thread, deserialize the TaskResult;
    3. If the result matches DirectTaskResult, check canFetchMoreResults and use the result directly;
    4. If it matches IndirectTaskResult, check whether fetching it would exceed the spark.driver.maxResultSize budget; if so, remove the remote block and give up, otherwise fetch the bytes from the BlockManager, deserialize them into the result, and remove the block;
    5. Set the RESULT_SIZE value in the result's accumulator updates;
    6. Call scheduler.handleSuccessfulTask (a simplified sketch of the direct/indirect branch follows).
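
    The driver-side branch mirrors the executor-side decision sketched earlier. Below is a hedged, self-contained model: the real code tracks the cumulative result size for the whole task set via canFetchMoreResults, while here a single size check stands in for it, and the types and fetchBlock function are hypothetical.

      // Same hypothetical stand-ins as in the executor-side sketch above.
      sealed trait MiniTaskResult
      final case class DirectResult(bytes: Array[Byte]) extends MiniTaskResult
      final case class IndirectResult(blockId: String, size: Long) extends MiniTaskResult

      def resolveResult(
          result: MiniTaskResult,
          maxResultSize: Long,                        // analogue of spark.driver.maxResultSize
          fetchBlock: String => Option[Array[Byte]]   // analogue of blockManager.getRemoteBytes
          ): Option[Array[Byte]] =
        result match {
          case DirectResult(bytes) =>
            Some(bytes)                               // bytes arrived inside the StatusUpdate itself
          case IndirectResult(_, size) if maxResultSize > 0 && size > maxResultSize =>
            None                                      // over budget: drop it, as the executor already warned
          case IndirectResult(blockId, _) =>
            fetchBlock(blockId)                       // pull the serialized result from the BlockManager
        }
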
    • Enter org.apache.spark.scheduler.TaskSetManager.scala
      /**
       * Marks a task as successful and notifies the DAGScheduler that the task has ended.
       */
      def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = {
        val info = taskInfos(tid)
        val index = info.index
        // Check if any other attempt succeeded before this and this attempt has not been handled
        if (successful(index) && killedByOtherAttempt.contains(tid)) {
          // Undo the effect on calculatedTasks and totalResultSize made earlier when
          // checking if can fetch more results
          calculatedTasks -= 1
          val resultSizeAcc = result.accumUpdates.find(a =>
            a.name == Some(InternalAccumulator.RESULT_SIZE))
          if (resultSizeAcc.isDefined) {
            totalResultSize -= resultSizeAcc.get.asInstanceOf[LongAccumulator].value
          }
    
          // Handle this task as a killed task
          handleFailedTask(tid, TaskState.KILLED,
            TaskKilled("Finish but did not commit due to another attempt succeeded"))
          return
        }
    
        info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
        if (speculationEnabled) {
          successfulTaskDurations.insert(info.duration)
        }
        removeRunningTask(tid)
    
        // Kill any other attempts for the same task (since those are unnecessary now that one
        // attempt completed successfully).
        for (attemptInfo <- taskAttempts(index) if attemptInfo.running) {
          logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " +
            s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " +
            s"as the attempt ${info.attemptNumber} succeeded on ${info.host}")
          killedByOtherAttempt += attemptInfo.taskId
          sched.backend.killTask(
            attemptInfo.taskId,
            attemptInfo.executorId,
            interruptThread = true,
            reason = "another attempt succeeded")
        }
        if (!successful(index)) {
          tasksSuccessful += 1
          logInfo(s"Finished task ${info.id} in stage ${taskSet.id} (TID ${info.taskId}) in" +
            s" ${info.duration} ms on ${info.host} (executor ${info.executorId})" +
            s" ($tasksSuccessful/$numTasks)")
          // Mark successful and stop if all the tasks have succeeded.
          successful(index) = true
          if (tasksSuccessful == numTasks) {
            isZombie = true
          }
        } else {
          logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
            " because task " + index + " has already completed successfully")
        }
        // There may be multiple tasksets for this stage -- we let all of them know that the partition
        // was completed.  This may result in some of the tasksets getting completed.
        sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
        // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
        // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
        // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
        // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
        // Note: "result.value()" only deserializes the value when it's called at the first time, so
        // here "result.value()" just returns the value and won't block other threads.
        sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info)
        maybeFinishTaskSet()
      }
    
    1. Mark the corresponding TaskInfo in TaskSetManager.taskInfos as FINISHED;
    2. Remove the task from TaskSetManager.runningTasksSet;
    3. Kill any other running attempts of the same task, since they are now redundant;
    4. Call dagScheduler.taskEnded to report that the task has completed;
    5. Call maybeFinishTaskSet to check whether the whole TaskSet has finished (a simplified sketch of this attempt bookkeeping follows).
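
    The core rule is that only the first successful attempt for a partition counts, and any other attempts still running are killed. A hedged, self-contained model of that bookkeeping follows; all names are hypothetical, and speculation, zombie handling, and accumulator corrections are omitted.

      import scala.collection.mutable

      final class MiniTaskSetManager(numTasks: Int, killAttempt: Long => Unit) {
        private val successful = Array.fill(numTasks)(false)
        private val runningAttempts = mutable.Map.empty[Int, mutable.Set[Long]] // partition index -> attempt TIDs
        private var tasksSuccessful = 0

        def registerAttempt(index: Int, tid: Long): Unit =
          runningAttempts.getOrElseUpdate(index, mutable.Set.empty) += tid

        def handleSuccessfulTask(index: Int, tid: Long): Unit = {
          // Kill the now-redundant attempts for the same partition.
          runningAttempts.getOrElse(index, mutable.Set.empty).filter(_ != tid).foreach(killAttempt)
          runningAttempts.remove(index)
          if (!successful(index)) {   // only the first completion counts
            successful(index) = true
            tasksSuccessful += 1
          }
          if (tasksSuccessful == numTasks) println("task set finished")
        }
      }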
