美文网首页
Spark 提交执行源码学习

Spark 提交执行源码学习

作者: ShiPF | 来源:发表于2021-12-11 23:32 被阅读0次

    SparkSubmit 提交后(YARN cluster 模式),ApplicationMaster 在 runDriver 中准备执行环境,并在单独的 Driver 线程中启动用户程序

    private def runDriver(): Unit = {
      addAmIpFilter(None, System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV))

      // Launch the user's main class on its own thread (named "Driver" by startUserApplication).
      userClassThread = startUserApplication()

      // This a bit hacky, but we need to wait until the spark.driver.port property has
      // been set by the Thread executing the user class.
      logInfo("Waiting for spark context initialization...")
      val maxWaitMs = sparkConf.get(AM_MAX_WAIT_TIME)
      val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
        Duration(maxWaitMs, TimeUnit.MILLISECONDS))

      // Sanity check; should never happen in normal operation, since sc should only be null
      // if the user app did not create a SparkContext.
      if (sc == null) {
        throw new IllegalStateException("User did not initialize spark context!")
      }

      val rpcEnv = sc.env.rpcEnv
      val userConf = sc.getConf
      val driverHost = userConf.get(DRIVER_HOST_ADDRESS)
      val driverPort = userConf.get(DRIVER_PORT)
      registerAM(driverHost, driverPort, userConf, sc.ui.map(_.webUrl), appAttemptId)

      // RPC reference to the YarnSchedulerBackend endpoint at the driver's host/port.
      val driverEndpoint = rpcEnv.setupEndpointRef(
        RpcAddress(driverHost, driverPort),
        YarnSchedulerBackend.ENDPOINT_NAME)
      createAllocator(driverEndpoint, userConf, rpcEnv, appAttemptId, distCacheConf)

      resumeDriver()
      // Block until the user thread has finished.
      userClassThread.join()
    }
    
    /**
       * Start the user class, which contains the spark driver, in a separate Thread.
       * The thread is named "Driver" and uses `userClassLoader` as its context class loader.
       * If the main routine exits cleanly or exits with System.exit(N) for any N
       * we assume it was successful, for all other cases we assume failure.
       *
       * Returns the user thread that was started.
       *
       * NOTE(review): excerpt only -- the catch cases of `run()` and the closing braces of
       * `run()` and of the anonymous Thread are elided, so the braces do not balance as shown.
       */
      private def startUserApplication(): Thread = {
        logInfo("Starting the user application in a separate Thread")
    
        // Arguments forwarded to the user's main method below.
        var userArgs = args.userArgs
     
        // Look up `main(Array[String])` on the user class through the user class loader.
        val mainMethod = userClassLoader.loadClass(args.userClass)
          .getMethod("main", classOf[Array[String]])
    
        val userThread = new Thread {
          override def run(): Unit = {
            try {
              // invoke(null, ...) below requires a static method, so a non-static main fails here.
              if (!Modifier.isStatic(mainMethod.getModifiers)) {
                logError(s"Could not find static main method in object ${args.userClass}")
                finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS)
              } else {
                mainMethod.invoke(null, userArgs.toArray)
                // main() returned normally: mark the application as SUCCEEDED.
                finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
                logDebug("Done running user class")
              }
            } catch {  
        }
        userThread.setContextClassLoader(userClassLoader)
        userThread.setName("Driver")
        userThread.start()
        userThread
      }
    

    开始执行用户代码

    假设用户代码如下

    # Build (or reuse) a Hive-enabled SparkSession with the resources used by this example.
    spark = (
        SparkSession.builder
        .config('spark.driver.memory', '4g')
        .config('spark.executor.memory', '4g')
        .config('spark.executor.instances', 2)
        .config('spark.executor.cores', 2)
        .config('spark.jars', '/usr/hdp/3.1.4.0-315/hadoop/lib/hll-credential-provider-v1.0.jar')
        # Hadoop recursive input-directory settings.
        .config('mapreduce.input.fileinputformat.input.dir.recursive', 'true')
        .config('mapred.input.dir.recursive', 'true')
        .config('spark.sql.hive.convertMetastoreOrc', 'false')
        .config('spark.yarn.queue', 'datawarehouse')
        .appName('yqj test')
        .enableHiveSupport()
        .getOrCreate()
    )

    sql = "select count(*) from ods.check_hive2_not_delete group by cityid"
    sql_run = spark.sql(sql)
    # show() is an action: it is what actually triggers execution of the query.
    sql_run.show()
    

    Step01,构建SparkSession对象

    /**
     * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
     * one based on the options set in this builder.
     *
     * Lookup order:
     *  1. the current thread's active SparkSession;
     *  2. the global default SparkSession;
     *  3. otherwise a brand-new SparkSession, which is then installed as both the global
     *     default and the active session.
     * An existing session is only reused if its SparkContext has not been stopped.
     *
     * When an existing SparkSession is returned, the non-static config options specified in
     * this builder are applied to it. This is what allows a SparkSession to be reused, e.g. in
     * CLI mode.
     *
     * @since 2.0.0
     */
    def getOrCreate(): SparkSession = synchronized {
      val sparkConf = new SparkConf()
      options.foreach { case (k, v) => sparkConf.set(k, v) }

      // A session can be reused when it exists and its SparkContext is still running.
      def isReusable(candidate: SparkSession): Boolean =
        (candidate ne null) && !candidate.sparkContext.isStopped

      // Apply this builder's modifiable options to an existing session and hand it back.
      def reuse(existing: SparkSession): SparkSession = {
        applyModifiableSettings(existing, new java.util.HashMap[String, String](options.asJava))
        existing
      }

      // Get the session from current thread's active session.
      val threadSession = activeThreadSession.get()
      if (isReusable(threadSession)) {
        reuse(threadSession)
      } else {
        // Global synchronization so we will only set the default session once.
        SparkSession.synchronized {
          // If the current thread does not have an active session, get it from the global session.
          val globalSession = defaultSession.get()
          if (isReusable(globalSession)) {
            reuse(globalSession)
          } else {
            // No active nor global default session. Create a new one.
            val sparkContext = userSuppliedContext.getOrElse {
              // set a random app name if not given.
              if (!sparkConf.contains("spark.app.name")) {
                sparkConf.setAppName(java.util.UUID.randomUUID().toString)
              }
              // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
              SparkContext.getOrCreate(sparkConf)
            }

            loadExtensions(extensions)
            applyExtensions(
              sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
              extensions)
            // Construct the session from the SparkContext and this builder's options.
            val created = new SparkSession(sparkContext, None, None, extensions, options.toMap)
            setDefaultSession(created)
            setActiveSession(created)
            registerContextListener(sparkContext)
            created
          }
        }
      }
    }
    

    再来看下 SparkSession 类的结构。SessionState 是一个核心类,很多属性都从中获取。

    class SparkSession private(
        @transient val sparkContext: SparkContext,
        @transient private val existingSharedState: Option[SharedState],
        @transient private val parentSessionState: Option[SessionState],
        @transient private[sql] val extensions: SparkSessionExtensions,
        @transient private[sql] val initialSessionOptions: Map[String, String])
    extends Serializable with Closeable with Logging {

      /**
       * State shared across sessions, including the `SparkContext`, cached data, listener,
       * and a catalog that interacts with external systems.
       *
       * Uses `existingSharedState` when one was supplied; otherwise a new [[SharedState]] is
       * built from `sparkContext` and `initialSessionOptions` on first access.
       *
       * This is internal to Spark and there is no guarantee on interface stability.
       *
       * @since 2.2.0
       */
      @Unstable
      @transient
      lazy val sharedState: SharedState = existingSharedState match {
        case Some(shared) => shared
        case None => new SharedState(sparkContext, initialSessionOptions)
      }

      /**
       * State isolated across sessions, including SQL configurations, temporary tables, registered
       * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
       *
       * When `parentSessionState` is defined, this session's state is a clone of the parent's,
       * bound to this session; otherwise it is instantiated from the class named by
       * `SparkSession.sessionStateClassName(sharedState.conf)`.
       *
       * This is internal to Spark and there is no guarantee on interface stability.
       *
       * @since 2.2.0
       */
      @Unstable
      @transient
      lazy val sessionState: SessionState = parentSessionState match {
        case Some(parent) => parent.clone(this)
        case None =>
          SparkSession.instantiateSessionState(
            SparkSession.sessionStateClassName(sharedState.conf),
            self)
      }
    }
    

    SessionState对象

    /**
     * A class that holds all session-specific state in a given [[SparkSession]].
     *
     * @param sharedState The state shared across sessions, e.g. global view manager, external catalog.
     * @param conf SQL-specific key-value configurations.
     * @param experimentalMethods Interface to add custom planning strategies and optimizers.
     * @param functionRegistry Internal catalog for managing functions registered by the user.
     * @param tableFunctionRegistry Registry for table functions (presumably the table-function
     *                              counterpart of `functionRegistry` -- confirm against upstream).
     * @param udfRegistration Interface exposed to the user for registering user-defined functions.
     * @param catalogBuilder a function to create an internal catalog for managing table and database
     *                       states.
     * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
     * @param analyzerBuilder A function to create the logical query plan analyzer for resolving
     *                        unresolved attributes and relations.
     * @param optimizerBuilder a function to create the logical query plan optimizer.
     * @param planner Planner that converts optimized logical plans to physical plans.
     * @param streamingQueryManagerBuilder A function to create a streaming query manager to
     *                                     start and stop streaming queries.
     * @param listenerManager Interface to register custom
     *                        [[org.apache.spark.sql.util.QueryExecutionListener]]s.
     * @param resourceLoaderBuilder a function to create a session shared resource loader to load JARs,
     *                              files, etc.
     * @param createQueryExecution Function used to create QueryExecution objects from a
     *                             [[LogicalPlan]] and a `CommandExecutionMode.Value`.
     * @param createClone Function used to create clones of the session state.
     * @param columnarRules [[ColumnarRule]]s held by this session (presumably applied when planning
     *                      columnar execution -- TODO confirm).
     * @param queryStagePrepRules Rules over [[SparkPlan]] held by this session (presumably applied
     *                            when preparing query stages -- TODO confirm).
     */
    
    private[sql] class SessionState(
        sharedState: SharedState,
        val conf: SQLConf,
        val experimentalMethods: ExperimentalMethods,
        val functionRegistry: FunctionRegistry,
        val tableFunctionRegistry: TableFunctionRegistry,
        val udfRegistration: UDFRegistration,
        catalogBuilder: () => SessionCatalog,
        val sqlParser: ParserInterface,
        analyzerBuilder: () => Analyzer,
        optimizerBuilder: () => Optimizer,
        val planner: SparkPlanner,
        val streamingQueryManagerBuilder: () => StreamingQueryManager,
        val listenerManager: ExecutionListenerManager,
        resourceLoaderBuilder: () => SessionResourceLoader,
        createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
        createClone: (SparkSession, SessionState) => SessionState,
        val columnarRules: Seq[ColumnarRule],
        val queryStagePrepRules: Seq[Rule[SparkPlan]])
    

    回到SparkSession,查看sql()

    /**
     * Executes a SQL query using Spark, returning the result as a `DataFrame`.
     * This API eagerly runs DDL/DML commands, but not for SELECT queries.
     *
     * A fresh [[QueryPlanningTracker]] times the parsing phase (`QueryPlanningTracker.PARSING`)
     * and is then handed to `Dataset.ofRows` together with the parsed logical plan.
     *
     * @since 2.0.0
     */
    def sql(sqlText: String): DataFrame = withActive {
      val planningTracker = new QueryPlanningTracker
      // Parse the SQL text into a logical plan, recording the time spent in the PARSING phase.
      val parsedPlan = planningTracker.measurePhase(QueryPlanningTracker.PARSING) {
        sessionState.sqlParser.parsePlan(sqlText)
      }
      // Wrap the logical plan into a DataFrame (a Dataset[Row]).
      Dataset.ofRows(self, parsedPlan, planningTracker)
    }
    
    // Companion object of QueryPlanningTracker (the tracker class itself is not shown here).
    /**
     * A simple utility for tracking runtime and associated stats in query planning.
     *
     * There are two separate concepts we track:
     *
     * 1. Phases: These are broad scope phases in query planning, as listed below, i.e. analysis,
     * optimization and physical planning (just planning). Parsing is tracked as a phase as well.
     *
     * 2. Rules: These are the individual Catalyst rules that we track. In addition to time, we also
     * track the number of invocations and effective invocations.
     */
    object QueryPlanningTracker{
       // Define a list of common phases here.
       // These names are passed to `measurePhase`, e.g. `measurePhase(QueryPlanningTracker.PARSING)`.
      val PARSING = "parsing"
      val ANALYSIS = "analysis"
      val OPTIMIZATION = "optimization"
      val PLANNING = "planning"
    }
    
    SQL Parse --> plan
    /**
     * Creates a [[LogicalPlan]] for a given SQL string.
     *
     * The text is parsed with the `singleStatement` grammar rule and visited by `astBuilder`.
     * If the visitor yields anything other than a `LogicalPlan`, the error built by
     * `QueryParsingErrors.sqlStatementUnsupportedError` is thrown.
     */
    override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { grammarParser =>
      val visited = astBuilder.visitSingleStatement(grammarParser.singleStatement())
      visited match {
        case logicalPlan: LogicalPlan => logicalPlan
        case _ =>
          throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, Origin(None, None))
      }
    }
    
    plan --> DataFrame
    /**
     * A variant of ofRows that allows passing in a tracker so we can track query parsing time.
     *
     * Wraps `logicalPlan` in a [[QueryExecution]] that shares `tracker`, calls `assertAnalyzed()`
     * on it, and returns a `Dataset[Row]` whose encoder is built from the analyzed schema.
     */
    def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
      : DataFrame = sparkSession.withActive {
      val execution = new QueryExecution(sparkSession, logicalPlan, tracker)
      execution.assertAnalyzed()
      val analyzedSchema = execution.analyzed.schema
      new Dataset[Row](execution, RowEncoder(analyzedSchema))
    }
    
    
    

    QueryExecution对象

    /**
     * The primary workflow for executing relational queries using Spark.  Designed to allow easy
     * access to the intermediate phases of query execution for developers.
     *
     * While this is not a public class, we should avoid changing the function names for the sake of
     * changing them, because a lot of developers use the feature for debugging.
     *
     * @param sparkSession the session the query belongs to
     * @param logical the logical plan to execute; `Dataset.ofRows` passes the plan produced by the
     *                SQL parser, before analysis
     * @param tracker records planning-phase timings; defaults to a new [[QueryPlanningTracker]]
     * @param mode command execution mode; defaults to `CommandExecutionMode.ALL`
     */
    class QueryExecution(
        val sparkSession: SparkSession,
        val logical: LogicalPlan,
        val tracker: QueryPlanningTracker = new QueryPlanningTracker,
        val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) extends Logging
    

    构建 Dataset[Row] 对象(即 DataFrame),其编码器 RowEncoder 基于分析后的 schema 生成

    /**
     * A Dataset is a strongly typed collection of domain-specific objects that can be transformed
     * in parallel using functional or relational operations. Each Dataset also has an untyped view
     * called a `DataFrame`, which is a Dataset of [[Row]].
     *
     * Operations available on Datasets are divided into transformations and actions. Transformations
     * are the ones that produce new Datasets, and actions are the ones that trigger computation and
     * return results. Example transformations include map, filter, select, and aggregate (`groupBy`).
     * Example actions include count, show, or writing data out to file systems.
     *
     * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally,
     * a Dataset represents a logical plan that describes the computation required to produce the data.
     * When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a
     * physical plan for efficient execution in a parallel and distributed manner. To explore the
     * logical plan as well as optimized physical plan, use the `explain` function.
     *
     * In short: nothing runs until an action; the action turns the optimized logical plan into a
     * distributed physical plan, and `explain` shows both.
     *
     * @param queryExecution the [[QueryExecution]] that holds this Dataset's plans
     * @param encoder the [[Encoder]] for the element type `T`
     *
     * @groupname basic Basic Dataset functions
     * @groupname action Actions
     * @groupname untypedrel Untyped transformations
     * @groupname typedrel Typed transformations
     *
     * @since 1.6.0
     */
    @Stable
    class Dataset[T] private[sql](
        @DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
        @DeveloperApi @Unstable @transient val encoder: Encoder[T])
    

    查看一个SQL Explain的结果

    sql = "select count(*) from ods.check_hive2_not_delete group by cityid"
    sql_run = spark.sql(sql)
    # Print the parsed, analyzed and optimized logical plans plus the physical plan.
    sql_run.explain(extended=True)
    
    == Parsed Logical Plan ==
    'Aggregate ['cityid], [unresolvedalias('count(1), None)]
    +- 'UnresolvedRelation `ods`.`check_hive2_not_delete`
    
    == Analyzed Logical Plan ==
    count(1): bigint
    Aggregate [cityid#85], [count(1) AS count(1)#95L]
    +- SubqueryAlias check_hive2_not_delete
       +- HiveTableRelation `ods`.`check_hive2_not_delete`, org.apache.hadoop.hive.ql.io.orc.OrcSerde, [id#84, cityid#85, lng#86, lat#87, prob#88, order_cnt#89, user_cnt#90, ratio#91, load_ratio#92, unload_ratio#93, 10m_dist_ratio#94]
    
    == Optimized Logical Plan ==
    Aggregate [cityid#85], [count(1) AS count(1)#95L]
    +- Project [cityid#85]
       +- HiveTableRelation `ods`.`check_hive2_not_delete`, org.apache.hadoop.hive.ql.io.orc.OrcSerde, [id#84, cityid#85, lng#86, lat#87, prob#88, order_cnt#89, user_cnt#90, ratio#91, load_ratio#92, unload_ratio#93, 10m_dist_ratio#94]
    
    == Physical Plan ==
    *(2) HashAggregate(keys=[cityid#85], functions=[count(1)], output=[count(1)#95L])
    +- Exchange hashpartitioning(cityid#85, 200)
       +- *(1) HashAggregate(keys=[cityid#85], functions=[partial_count(1)], output=[cityid#85, count#98L])
          +- HiveTableScan [cityid#85], HiveTableRelation `ods`.`check_hive2_not_delete`, org.apache.hadoop.hive.ql.io.orc.OrcSerde, [id#84, cityid#85, lng#86, lat#87, prob#88, order_cnt#89, user_cnt#90, ratio#91, load_ratio#92, unload_ratio#93, 10m_dist_ratio#94]
    

    Action 算子最终触发 SparkContext 的 runJob 方法

    /**
     * Run a function on a given set of partitions in an RDD and pass the results to the given
     * handler function. This is the main entry point for all actions in Spark.
     *
     * Throws an IllegalStateException if this SparkContext has been stopped. Otherwise it passes
     * `func` through `clean`, logs the RDD lineage when `spark.logLineage` is true, delegates to
     * `dagScheduler.runJob`, and finally calls `rdd.doCheckpoint()`.
     *
     * @param rdd target RDD to run tasks on
     * @param func a function to run on each partition of the RDD
     * @param partitions set of partitions to run on; some jobs may not want to compute on all
     * partitions of the target RDD, e.g. for operations like `first()`
     * @param resultHandler callback to pass each result to
     */
    def runJob[T, U: ClassTag](
        rdd: RDD[T],
        func: (TaskContext, Iterator[T]) => U,
        partitions: Seq[Int],
        resultHandler: (Int, U) => Unit): Unit = {
      if (stopped.get()) {
        throw new IllegalStateException("SparkContext has been shutdown")
      }
      val site = getCallSite
      val closure = clean(func)
      logInfo(s"Starting job: ${site.shortForm}")
      val logLineage = conf.getBoolean("spark.logLineage", false)
      if (logLineage) {
        logInfo(s"RDD's recursive dependencies:\n${rdd.toDebugString}")
      }
      dagScheduler.runJob(rdd, closure, partitions, site, resultHandler, localProperties.get)
      progressBar.foreach(bar => bar.finishAll())
      rdd.doCheckpoint()
    }
    

    最终调用的是 DAGScheduler 的 runJob,其内部通过 submitJob 提交作业,submitJob 代码如下

      /**
       * Submit an action job to the scheduler.
       *
       * @param rdd target RDD to run tasks on
       * @param func a function to run on each partition of the RDD
       * @param partitions set of partitions to run on; some jobs may not want to compute on all
       *   partitions of the target RDD, e.g. for operations like first()
       * @param callSite where in the user program this job was called
       * @param resultHandler callback to pass each result to
       * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
       *
       * @return a JobWaiter object that can be used to block until the job finishes executing
       *         or can be used to cancel the job.
       *
       * @throws IllegalArgumentException when partitions ids are illegal, i.e. any id is negative
       *         or not smaller than the number of partitions of `rdd`
       */
      def submitJob[T, U](
          rdd: RDD[T],
          func: (TaskContext, Iterator[T]) => U,
          partitions: Seq[Int],
          callSite: CallSite,
          resultHandler: (Int, U) => Unit,
          properties: Properties): JobWaiter[U] = {
        // Check to make sure we are not launching a task on a partition that does not exist.
        // Validating here, before the job is posted to the event loop, makes an invalid id fail
        // with the documented IllegalArgumentException at the call site.
        val maxPartitions = rdd.partitions.length
        partitions.find(p => p < 0 || p >= maxPartitions).foreach { p =>
          throw new IllegalArgumentException(
            s"Attempting to access a non-existent partition: $p. " +
              s"Total number of partitions: $maxPartitions")
        }

        // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute
        // `.partitions` on every RDD in the DAG to ensure that `getPartitions()`
        // is evaluated outside of the DAGScheduler's single-threaded event loop:
        eagerlyComputePartitionsForRddAndAncestors(rdd)

        val jobId = nextJobId.getAndIncrement()
        // Erase the element types to match the untyped function expected by handleJobSubmitted.
        val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
        val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)

        // Submit the job: post a JobSubmitted event (job id, RDD, function, partitions, call site,
        // waiter and a copy of the properties) to the event loop, then return the waiter.
        eventProcessLoop.post(JobSubmitted(
          jobId, rdd, func2, partitions.toArray, callSite, waiter,
          Utils.cloneProperties(properties)))
        waiter
      }
    

    DAGScheduler 的类信息

    /**
     * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
     * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a
     * minimal schedule to run the job. It then submits stages as TaskSets to an underlying
     * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent
     * tasks that can run right away based on the data that's already on the cluster (e.g. map output
     * files from previous stages), though it may fail if this data becomes unavailable.
     *
     * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with
     * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks
     * in each stage, but operations with shuffle dependencies require multiple stages (one to write a
     * set of map output files, and another to read those files after a barrier). In the end, every
     * stage will have only shuffle dependencies on other stages, and may compute multiple operations
     * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of
     * various RDDs
     *
     * (Reader's note: in other words, a wide/shuffle dependency starts a new stage, while narrow
     * dependencies stay inside the same stage.)
     *
     * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred
     * locations to run each task on, based on the current cache status, and passes these to the
     * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
     * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
     * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
     * a small number of times before cancelling the whole stage.
     *
     * (Reader's note: the DAGScheduler decides where tasks run and recovers lost shuffle output by
     * resubmitting earlier stages; other task failures are retried by the TaskScheduler a few
     * times before the whole stage is cancelled.)
     *
     * When looking through this code, there are several key concepts:
     *
     *  - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler.
     *    For example, when the user calls an action, like count(), a job will be submitted through
     *    submitJob. Each Job may require the execution of multiple stages to build intermediate data.
     *    (Reader's note: one action produces one job; a job may span several stages.)
     *
     *  - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each
     *    task computes the same function on partitions of the same RDD. Stages are separated at shuffle
     *    boundaries, which introduce a barrier (where we must wait for the previous stage to finish to
     *    fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that
     *    executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle.
     *    Stages are often shared across multiple jobs, if these jobs reuse the same RDDs.
     *
     *  - Tasks are individual units of work, each sent to one machine.
     *
     *  - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them
     *    and likewise remembers which shuffle map stages have already produced output files to avoid
     *    redoing the map side of a shuffle.
     *
     *  - Preferred locations: the DAGScheduler also computes where to run each task in a stage based
     *    on the preferred locations of its underlying RDDs, or the location of cached or shuffle data.
     *
     *  - Cleanup: all data structures are cleared when the running jobs that depend on them finish,
     *    to prevent memory leaks in a long-running application.
     *
     * To recover from failures, the same stage might need to run multiple times, which are called
     * "attempts". If the TaskScheduler reports that a task failed because a map output file from a
     * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a
     * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small
     * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost
     * stage(s) that compute the missing tasks. As part of this process, we might also have to create
     * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since
     * tasks from the old attempt of a stage could still be running, care must be taken to map any
     * events received in the correct Stage object.
     *
     * Here's a checklist to use when making or reviewing changes to this class:
     *
     *  - All data structures should be cleared when the jobs involving them end to avoid indefinite
     *    accumulation of state in long-running programs.
     *
     *  - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to
     *    include the new structure. This will help to catch memory leaks.
     */
    private[spark] class DAGScheduler(
        private[scheduler] val sc: SparkContext,
        private[scheduler] val taskScheduler: TaskScheduler,
        listenerBus: LiveListenerBus,
        mapOutputTracker: MapOutputTrackerMaster,
        blockManagerMaster: BlockManagerMaster,
        env: SparkEnv,
        clock: Clock = new SystemClock())
    

    上面 submitJob 投递到 eventProcessLoop 的 JobSubmitted 事件,由事件循环的 doOnReceive 接收,并转交给 DAGScheduler.handleJobSubmitted 处理

     /**
      * Dispatches a DAGSchedulerEvent received by the event loop to the matching DAGScheduler
      * handler; e.g. the JobSubmitted event posted by submitJob goes to handleJobSubmitted.
      *
      * NOTE(review): excerpt only -- the remaining cases are elided as `.....` below, so this
      * snippet does not compile as shown.
      */
     private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
        case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
          dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
    
        case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
          dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)
    
        case StageCancelled(stageId, reason) =>
          dagScheduler.handleStageCancellation(stageId, reason)
       .....
     }
    
    /**
     * Handles a JobSubmitted event. It builds the job's final ResultStage (which creates or
     * reuses its parent stages), registers a new ActiveJob, posts SparkListenerJobStart to the
     * listener bus and finally submits the final stage.
     *
     * NOTE(review): excerpt only -- parts of the method are elided as `.....`, including the
     * handler that closes the `try`, so the braces do not balance as shown.
     */
    private[scheduler] def handleJobSubmitted(jobId: Int,
          finalRDD: RDD[_],
          func: (TaskContext, Iterator[_]) => _,
          partitions: Array[Int],
          callSite: CallSite,
          listener: JobListener,
          properties: Properties): Unit = {
        var finalStage: ResultStage = null
        try {
          // New stage creation may throw an exception if, for example, jobs are run on a
          // HadoopRDD whose underlying HDFS files have been deleted.
          // Build the final ResultStage; its parent stages are created recursively, i.e. this is
          // where the job is split into stages.
          finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
          
          .....
          
          val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
        clearCacheLocs()
        
        val jobSubmissionTime = clock.getTimeMillis()
        jobIdToActiveJob(jobId) = job
        activeJobs += job
        finalStage.setActiveJob(job)
        val stageIds = jobIdToStageIds(jobId).toArray
        val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
          
        // Announce the job start (with the info of all its stages), then submit the final stage.
        listenerBus.post(
          SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos,
            Utils.cloneProperties(properties)))
        submitStage(finalStage)
      }
    

    先看下 Stage 划分的过程

    /**
     * Create a ResultStage associated with the provided jobId.
     *
     * Steps: collect the immediate parent shuffle dependencies and resource profiles of `rdd`,
     * merge the profiles into one for this stage, run the barrier-stage checks, get or create
     * the parent ShuffleMapStages, then allocate a new stage id, register the stage in
     * `stageIdToStage` and link it to `jobId`.
     */
    private def createResultStage(
        rdd: RDD[_],
        func: (TaskContext, Iterator[_]) => _,
        partitions: Array[Int],
        jobId: Int,
        callSite: CallSite): ResultStage = {
      // Immediate (one level up) parent shuffle dependencies, plus the resource profiles found.
      val (parentShuffleDeps, stageProfiles) = getShuffleDependenciesAndResourceProfiles(rdd)
      val mergedProfile = mergeResourceProfilesForStage(stageProfiles)

      // Barrier-stage sanity checks (see the individual check methods).
      checkBarrierStageWithDynamicAllocation(rdd)
      checkBarrierStageWithNumSlots(rdd, mergedProfile)
      checkBarrierStageWithRDDChainPattern(rdd, partitions.distinct.length)

      // Resolve the parent ShuffleMapStages; missing ancestors are created along the way.
      val parentStages = getOrCreateParentStages(parentShuffleDeps, jobId)

      val stageId = nextStageId.getAndIncrement()
      val resultStage = new ResultStage(stageId, rdd, func, partitions, parentStages, jobId,
        callSite, mergedProfile.id)
      stageIdToStage(stageId) = resultStage
      updateJobIdStageIdMaps(jobId, resultStage)
      resultStage
    }
    
    
    /**
     * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for
     * that dependency. Only includes stages that are part of currently running job (when the job(s)
     * that require the shuffle stage complete, the mapping will be removed, and the only record of
     * the shuffle data will be in the MapOutputTracker).
     *
     * getOrCreateShuffleMapStage consults this map first, so each shuffle dependency is backed by
     * at most one ShuffleMapStage while it is present here.
     */
    private[scheduler] val shuffleIdToMapStage: HashMap[Int, ShuffleMapStage] =
      HashMap.empty[Int, ShuffleMapStage]
    
    
    
    /**
     * Returns shuffle dependencies that are immediate parents of the given RDD and the
     * ResourceProfiles associated with the RDDs for this stage.
     *
     * This function will not return more distant ancestors for shuffle dependencies. For example,
     * if C has a shuffle dependency on B which has a shuffle dependency on A:
     *
     * A <-- B <-- C
     *
     * calling this function with rdd C will only return the B <-- C dependency.
     *
     * The walk is depth-first and follows only non-shuffle dependencies; each shuffle dependency
     * met on the way is recorded and not crossed. Every RDD is visited at most once.
     *
     * This function is scheduler-visible for the purpose of unit testing.
     */
    private[scheduler] def getShuffleDependenciesAndResourceProfiles(
        rdd: RDD[_]): (HashSet[ShuffleDependency[_, _, _]], HashSet[ResourceProfile]) = {
      val shuffleParents = new HashSet[ShuffleDependency[_, _, _]]
      val profiles = new HashSet[ResourceProfile]
      val seen = new HashSet[RDD[_]]

      // `pending` acts as a stack: its head is the next RDD to visit.
      @scala.annotation.tailrec
      def walk(pending: List[RDD[_]]): Unit = pending match {
        case Nil =>
        case current :: rest if seen(current) =>
          walk(rest)
        case current :: rest =>
          seen += current
          Option(current.getResourceProfile).foreach(profiles += _)
          val next = current.dependencies.foldLeft(rest) {
            case (stack, shuffleDep: ShuffleDependency[_, _, _]) =>
              shuffleParents += shuffleDep
              stack
            case (stack, narrowDep) =>
              narrowDep.rdd :: stack
          }
          walk(next)
      }

      walk(List(rdd))
      (shuffleParents, profiles)
    }
    
    /**
       * Resolves each of the given shuffle dependencies to its ShuffleMapStage, creating any stage
       * that does not exist yet. Newly created stages are tagged with the provided firstJobId.
       *
       * @param shuffleDeps the immediate shuffle dependencies of the stage being built
       * @param firstJobId  job id recorded on any stage created here
       * @return the parent stages
       */
      private def getOrCreateParentStages(shuffleDeps: HashSet[ShuffleDependency[_, _, _]],
          firstJobId: Int): List[Stage] = {
        // Every parent sits on the other side of a shuffle, so each one is a ShuffleMapStage
        // that is either looked up or created here.
        val parentStages = for {
          dep <- shuffleDeps
        } yield getOrCreateShuffleMapStage(dep, firstJobId)
        parentStages.toList
      }
    
    /**
       * Returns the ShuffleMapStage already registered in shuffleIdToMapStage for this dependency's
       * shuffle id. If there is none, stages are first created for every ancestor shuffle
       * dependency that is still missing, and then for this dependency itself.
       *
       * @param shuffleDep the shuffle dependency whose map stage is requested
       * @param firstJobId job id recorded on any stage created here
       * @return the existing or newly created ShuffleMapStage
       */
      private def getOrCreateShuffleMapStage(
          shuffleDep: ShuffleDependency[_, _, _],
          firstJobId: Int): ShuffleMapStage = {
        shuffleIdToMapStage.get(shuffleDep.shuffleId).getOrElse {
          // Create stages for all missing ancestor shuffle dependencies first.
          for (ancestorDep <- getMissingAncestorShuffleDependencies(shuffleDep.rdd)) {
            // The ancestor list only holds dependencies that were missing when it was computed,
            // but creating the stage for an earlier entry can already register a later one
            // (SPARK-13902), so check again before creating.
            if (!shuffleIdToMapStage.contains(ancestorDep.shuffleId)) {
              createShuffleMapStage(ancestorDep, firstJobId)
            }
          }
          // Finally, create the stage for the requested shuffle dependency itself.
          createShuffleMapStage(shuffleDep, firstJobId)
        }
      }
    
    // Creating a ShuffleMapStage closely mirrors creating a ResultStage.
     /**
       * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions, and
       * registers it in stageIdToStage, shuffleIdToMapStage and the job/stage maps.
       *
       * If the shuffle is not yet known to the MapOutputTracker, it is registered there with the
       * map-side partition count and the partitioner's partition count. If it is already registered
       * (a previously run stage produced the same shuffle data), it is left untouched, so map output
       * locations that are still available can be reused instead of regenerating the data.
       *
       * @param shuffleDep the shuffle dependency whose map side this stage computes
       * @param jobId      id of the job that first needs this stage
       * @return the newly created ShuffleMapStage
       */
      def createShuffleMapStage[K, V, C](
          shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = {
        val rdd = shuffleDep.rdd
        val (shuffleDeps, resourceProfiles) = getShuffleDependenciesAndResourceProfiles(rdd)
        val resourceProfile = mergeResourceProfilesForStage(resourceProfiles)
        // Barrier-execution sanity checks (presumably these throw when a constraint is violated).
        checkBarrierStageWithDynamicAllocation(rdd)
        checkBarrierStageWithNumSlots(rdd, resourceProfile)
        checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
        val numTasks = rdd.partitions.length
        // Parent stages are resolved (and created if missing) before this stage's id is allocated.
        val parents = getOrCreateParentStages(shuffleDeps, jobId)
        val id = nextStageId.getAndIncrement()
        // Create the ShuffleMapStage: one task per partition of the map-side RDD.
        val stage = new ShuffleMapStage(
          id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker,
          resourceProfile.id)
    
        stageIdToStage(id) = stage
        // Register the stage under its shuffle id so later lookups reuse it.
        shuffleIdToMapStage(shuffleDep.shuffleId) = stage
        updateJobIdStageIdMaps(jobId, stage)
    
        if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
          // Kind of ugly: need to register RDDs with the cache and map output tracker here
          // since we can't do it in the RDD constructor because # of partitions is unknown
          logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
            s"shuffle ${shuffleDep.shuffleId}")
          mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length,
            shuffleDep.partitioner.numPartitions)
        }
        stage
      }
    

    执行完上面的代码后,所有Stage都已创建完毕;下面是submitStage(finalStage)的代码

    /**
       * Submits `stage` once all of its parent stages are available. Otherwise the missing parents
       * are submitted first (recursively, lowest stage id first) and `stage` is parked in
       * waitingStages. Stages that are already waiting, running or failed are left alone, and a
       * stage without an active job is aborted.
       */
      private def submitStage(stage: Stage): Unit = {
        activeJobForStage(stage) match {
          case Some(activeJobId) =>
            logDebug(s"submitStage($stage (name=${stage.name};" +
              s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))")
            val alreadyTracked = waitingStages(stage) || runningStages(stage) || failedStages(stage)
            if (!alreadyTracked) {
              // Missing parent stages, ordered by stage id.
              val missingParents = getMissingParentStages(stage).sortBy(_.id)
              logDebug("missing: " + missingParents)
              if (missingParents.nonEmpty) {
                // Submit the parents first; this stage waits until they complete.
                missingParents.foreach(parent => submitStage(parent))
                waitingStages += stage
              } else {
                logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
                // Nothing is missing: build and submit this stage's tasks.
                submitMissingTasks(stage, activeJobId)
              }
            }
          case None =>
            abortStage(stage, "No active job for stage " + stage.id, None)
        }
      }
    
    
     /**
       * Called when stage's parents are available and we can now do its task.
       *
       * Builds one task per partition to compute (ShuffleMapTask for a ShuffleMapStage, ResultTask
       * for a ResultStage) and, if any were built, submits them as a single TaskSet.
       */
      private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
        
      // Excerpt: fault-tolerance handling and preferred-location computation are elided here.
      ......... 容错和最佳位置计算
      // NOTE(review): no catch/finally is shown for this `try`; the excerpt presumably elides it.
      val tasks: Seq[Task[_]] = try {
          // Task metrics are serialized once and the same bytes are passed to every task below.
          val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
          stage match {
            case stage: ShuffleMapStage =>
              // Reset the partitions this attempt waits on; they are re-added one per task below.
              stage.pendingPartitions.clear()
              partitionsToCompute.map { id =>
                val locs = taskIdToLocations(id)
                val part = partitions(id)
                stage.pendingPartitions += id
                new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
                  taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
                  Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
              }
    
            case stage: ResultStage =>
              partitionsToCompute.map { id =>
                // `id` indexes the stage's own partition list; stage.partitions(id) gives the
                // index p of the RDD partition to compute.
                val p: Int = stage.partitions(id)
                val part = partitions(p)
                val locs = taskIdToLocations(id)
                new ResultTask(stage.id, stage.latestInfo.attemptNumber,
                  taskBinary, part, locs, id, properties, serializedTaskMetrics,
                  Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
                  stage.rdd.isBarrier())
              }
          }
        }
       
        // Submit all tasks of this stage attempt to the TaskScheduler as one TaskSet.
        if (tasks.nonEmpty) {
          logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
            s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
          taskScheduler.submitTasks(new TaskSet(
            tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties,
            stage.resourceProfileId))
        } 
        
      }
      
    

    到这里,stage的划分和提交阶段已经结束,后面是task的调度和执行阶段

    在Driver端,task的调度由 TaskSchedulerImpl 类负责

    /**
     * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
     * It can also work with a local setup by using a `LocalSchedulerBackend` and setting
     * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking
     * up to launch speculative tasks, etc.
     *
     * Clients should first call initialize() and start(), then submit task sets through the
     * submitTasks method.
     *
     * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple
     * threads, so it needs locks in public API methods to maintain its state. In addition, some
     * [[SchedulerBackend]]s synchronize on themselves when they want to send events here, and then
     * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
     * we are holding a lock on ourselves.  This class is called from many threads, notably:
     *   * The DAGScheduler Event Loop
     *   * The RPCHandler threads, responding to status updates from Executors
     *   * Periodic revival of all offers from the CoarseGrainedSchedulerBackend, to accommodate delay
     *      scheduling
     *   * task-result-getter threads
     *
     * CAUTION: Any non fatal exception thrown within Spark RPC framework can be swallowed.
     * Thus, throwing exception in methods like resourceOffers, statusUpdate won't fail
     * the application, but could lead to undefined behavior. Instead, we shall use method like
     * TaskSetManager.abort() to abort a stage and then fail the application (SPARK-31485).
     *
     * Delay Scheduling:
     *  Delay scheduling is an optimization that sacrifices job fairness for data locality in order to
     *  improve cluster and workload throughput. One useful definition of "delay" is how much time
     *  has passed since the TaskSet was using its fair share of resources. Since it is impractical to
     *  calculate this delay without a full simulation, the heuristic used is the time since the
     *  TaskSetManager last launched a task and has not rejected any resources due to delay scheduling
     *  since it was last offered its "fair share". A "fair share" offer is when [[resourceOffers]]'s
     *  parameter "isAllFreeResources" is set to true. A "delay scheduling reject" is when a resource
     *  is not utilized despite there being pending tasks (implemented inside [[TaskSetManager]]).
     *  The legacy heuristic only measured the time since the [[TaskSetManager]] last launched a task,
     *  and can be re-enabled by setting spark.locality.wait.legacyResetOnTaskLaunch to true.
     *
     * @param sc              the application's SparkContext
     * @param maxTaskFailures passed to every TaskSetManager created here: once any single task
     *                        fails this many times, its task set is aborted
     * @param isLocal         true when used with a `LocalSchedulerBackend` (local mode)
     * @param clock           time source; defaults to a SystemClock
     */
    private[spark] class TaskSchedulerImpl(
        val sc: SparkContext,
        val maxTaskFailures: Int,
        isLocal: Boolean = false,
        clock: Clock = new SystemClock)
      extends TaskScheduler with Logging
    
    /**
     * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
     * each task, retries tasks if they fail (up to a limited number of times), and
     * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
     * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
     * and handleSuccessfulTask/handleFailedTask, which tells it that one of its tasks changed state
     *  (e.g. finished/failed).
     *
     * THREADING: This class is designed to only be called from code with a lock on the
     * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
     *
     * @param sched           the TaskSchedulerImpl associated with the TaskSetManager
     * @param taskSet         the TaskSet to manage scheduling for
     * @param maxTaskFailures if any particular task fails this number of times, the entire
     *                        task set will be aborted
     * @param healthTracker   optional tracker of excluded (unhealthy) executors and nodes;
     *                        None by default
     * @param clock           time source; defaults to a new SystemClock
     */
    private[spark] class TaskSetManager(
        sched: TaskSchedulerImpl,
        val taskSet: TaskSet,
        val maxTaskFailures: Int,
        healthTracker: Option[HealthTracker] = None,
        clock: Clock = new SystemClock()) extends Schedulable with Logging
    
    
    /**
     * Accepts a TaskSet submitted by the DAGScheduler: wraps it in a new TaskSetManager, marks
     * every earlier attempt of the same stage as zombie, hands the manager to the scheduling pool,
     * and finally asks the backend to revive offers. On the first submission in non-local mode,
     * a starvation timer is started that keeps warning until some task has been launched.
     */
    override def submitTasks(taskSet: TaskSet): Unit = {
        val taskCount = taskSet.tasks.length
        logInfo("Adding task set " + taskSet.id + " with " + taskCount + " tasks "
          + "resource profile " + taskSet.resourceProfileId)
        this.synchronized {
          val newManager = createTaskSetManager(taskSet, maxTaskFailures)
          val attemptsOfStage = taskSetsByStageIdAndAttempt.getOrElseUpdate(
            taskSet.stageId, new HashMap[Int, TaskSetManager])

          // A stage may only have one active TaskSetManager, so every earlier attempt is marked
          // zombie before the new one is registered. Corner case this covers: zombie TSM1 finishes
          // the last partition while active TSM2 still waits for it; the DAGScheduler then treats
          // the stage as complete and, if map outputs are missing, resubmits it as TSM3, so TSM2
          // (effectively finished) must be marked zombie as well.
          attemptsOfStage.values.foreach(previous => previous.isZombie = true)
          attemptsOfStage(taskSet.stageAttemptId) = newManager
          // Hand the manager to the scheduling pool so resource offers can pick up its tasks.
          schedulableBuilder.addTaskSetManager(newManager, newManager.taskSet.properties)

          val isFirstSubmission = !hasReceivedTask
          if (!isLocal && isFirstSubmission) {
            starvationTimer.scheduleAtFixedRate(new TimerTask() {
              override def run(): Unit = {
                if (hasLaunchedTask) {
                  // Something has launched; this watchdog is no longer needed.
                  this.cancel()
                } else {
                  logWarning("Initial job has not accepted any resources; " +
                    "check your cluster UI to ensure that workers are registered " +
                    "and have sufficient resources")
                }
              }
            }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
          }
          hasReceivedTask = true
        }
        // Called after releasing our lock: some backends lock themselves and then this scheduler,
        // so we must not lock the backend while holding our own lock.
        backend.reviveOffers()
      }
    
    /**
     * Asks for a fresh round of resource offers by sending a ReviveOffers message to the driver
     * endpoint. Non-fatal errors are handled by Utils.tryLogNonFatalError.
     */
    override def reviveOffers(): Unit = {
      Utils.tryLogNonFatalError(driverEndpoint.send(ReviveOffers))
    }
    
    
    

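    TaskSchedulerImpl.submitTasks 最后调用 backend.reviveOffers(),其实现是通过 driverEndpoint 发送 ReviveOffers 消息。driverEndpoint 是 CoarseGrainedSchedulerBackend 在 Driver 端的 RPC 端点,所以这其实相当于自己发给自己。DriverEndpoint 接收到 ReviveOffers 消息后,会调用 makeOffers()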

    调度类

    /**
     * A scheduler backend that waits for coarse-grained executors to connect.
     * This backend holds onto each executor for the duration of the Spark job rather than relinquishing
     * executors whenever a task is done and asking the scheduler to launch a new executor for
     * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the
     * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode
     * (spark.deploy.*).
     *
     * @param scheduler the TaskSchedulerImpl whose resourceOffers() assigns tasks to executor offers
     * @param rpcEnv    the RPC environment used by this backend
     */
    private[spark]
    class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)
      extends ExecutorAllocationClient with SchedulerBackend with Logging
    
    /**
     * Makes resource offers on every active executor and launches whatever tasks the scheduler
     * assigns to those offers.
     */
    private def makeOffers(): Unit = {
      // Build the offers and assign tasks inside withLock: make sure no executor is killed while
      // some task is launching on it.
      val taskDescs = withLock {
        // Executors that are being killed are filtered out by isExecutorActive.
        val offers = executorDataMap.filterKeys(isExecutorActive).map { case (execId, data) =>
          val freeResources = data.resourcesInfo.map { case (resourceName, info) =>
            (resourceName, info.availableAddrs.toBuffer)
          }
          new WorkerOffer(execId, data.executorHost, data.freeCores,
            Some(data.executorAddress.hostPort), freeResources, data.resourceProfileId)
        }.toIndexedSeq
        // Let the scheduler assign tasks to these offers (isAllFreeResources = true).
        scheduler.resourceOffers(offers, true)
      }
      if (taskDescs.nonEmpty) launchTasks(taskDescs)
    }
    

    TaskSchedulerImpl 类中为offers分配资源

    /**
       * Called by cluster manager to offer resources on workers. We respond by asking our active task
       * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
       * that tasks are balanced across the cluster.
       *
       * @param offers             free resources on each executor
       * @param isAllFreeResources true when the offers cover all free resources (a "fair share"
       *                           offer in delay-scheduling terms)
       * @return one sequence of assigned TaskDescriptions per (shuffled) offer
       */
      def resourceOffers(
          offers: IndexedSeq[WorkerOffer],
          isAllFreeResources: Boolean = true): Seq[Seq[TaskDescription]] = synchronized {
        
        // Mark each worker as alive and remember its hostname
        // Also track if new executor is added
        var newExecAvail = false
        for (o <- offers) {
          if (!hostToExecutors.contains(o.host)) {
            hostToExecutors(o.host) = new HashSet[String]()
          }
          if (!executorIdToRunningTaskIds.contains(o.executorId)) {
            hostToExecutors(o.host) += o.executorId
            executorAdded(o.executorId, o.host)
            executorIdToHost(o.executorId) = o.host
            executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
            newExecAvail = true
          }
        }
        val hosts = offers.map(_.host).distinct
        for ((host, Some(rack)) <- hosts.zip(getRacksForHosts(hosts))) {
          hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += host
        }
        
        // Before making any offers, include any nodes whose expireOnFailure timeout has expired. Do
        // this here to avoid a separate thread and added synchronization overhead, and also because
        // updating the excluded executors and nodes is only relevant when task offers are being made.
        healthTrackerOpt.foreach(_.applyExcludeOnFailureTimeout())
    
        val filteredOffers = healthTrackerOpt.map { healthTracker =>
          offers.filter { offer =>
            !healthTracker.isNodeExcluded(offer.host) &&
              !healthTracker.isExecutorExcluded(offer.executorId)
          }
        }.getOrElse(offers)
        // Shuffle the offers randomly so that tasks are not always placed on the same workers.
        val shuffledOffers = shuffleOffers(filteredOffers)
        
        // Build a list of tasks to assign to each worker.
        // Note the size estimate here might be off with different ResourceProfiles but should be
        // close estimate
        val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK))
        val availableResources = shuffledOffers.map(_.resources).toArray
        val availableCpus = shuffledOffers.map(o => o.cores).toArray
        val resourceProfileIds = shuffledOffers.map(o => o.resourceProfileId).toArray
        val sortedTaskSets = rootPool.getSortedTaskSetQueue
        for (taskSet <- sortedTaskSets) {
          logDebug("parentName: %s, name: %s, runningTasks: %s".format(
            taskSet.parent.name, taskSet.name, taskSet.runningTasks))
          if (newExecAvail) {
            taskSet.executorAdded()
          }
        }
        // Locality-aware assignment:
        // Take each TaskSet in our scheduling order, and then offer it to each node in increasing order
        // of locality levels so that it gets a chance to launch local tasks on all of them.
        // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
        for (taskSet <- sortedTaskSets) {
          .....
           val (noDelayScheduleReject, minLocality) = resourceOfferSingleTaskSet(
                  taskSet, currentMaxLocality, shuffledOffers, availableCpus,
                  availableResources, tasks)
        }
        
        // Record that tasks were handed out; the starvation timer started in submitTasks reads
        // hasLaunchedTask.
        // NOTE(review): `tasks` holds one buffer per shuffled offer, so this is true whenever at
        // least one offer survived filtering, even if no task was actually assigned — confirm intent.
        if (tasks.nonEmpty) {
          hasLaunchedTask = true
        }
        return tasks.map(_.toSeq)
    

    启动Task(这一步仍然在Driver端执行)

     // Launch tasks returned by a set of resource offers (TODO(spf): annotate the elided part).
     // Still runs on the driver: each TaskDescription is serialized and sent to its executor's
     // endpoint as a LaunchTask message.
        private def launchTasks(tasks: Seq[Seq[TaskDescription]]): Unit = {
          for (task <- tasks.flatten) {
            // Serialize the task description.
            val serializedTask = TaskDescription.encode(task)
            
            // Elided in this excerpt; `executorData` used below is defined in the elided code.
            ......
             // Send the serialized task to the executor.
              executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
            }
          }
        }
    

    CoarseGrainedExecutorBackend 接收消息

    // Message handler of CoarseGrainedExecutorBackend (only two cases are shown in this excerpt).
    override def receive: PartialFunction[Any, Unit] = {
        // Registration with the driver succeeded: create the Executor and report it as launched.
        case RegisteredExecutor =>
          logInfo("Successfully registered with driver")
          try {
            executor = new Executor(executorId, hostname, env, getUserClassPath, isLocal = false,
              resources = _resources)
            driver.get.send(LaunchedExecutor(executorId))
          } catch {
            case NonFatal(e) =>
              exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
          }
    
        // A task sent by the driver's launchTasks.
        case LaunchTask(data) =>
          if (executor == null) {
            exitExecutor(1, "Received LaunchTask command but executor was null")
          } else {
            // Deserialize the TaskDescription that the driver encoded with TaskDescription.encode.
            val taskDesc = TaskDescription.decode(data.value)
            logInfo("Got assigned task " + taskDesc.taskId)
            taskResources(taskDesc.taskId) = taskDesc.resources
            // Run the task (Executor.launchTask hands it to the executor's thread pool).
            executor.launchTask(this, taskDesc)
          }
    
    

    Executor 类信息

    /**
     * Spark executor, backed by a threadpool to run tasks.
     *
     * This can be used with Mesos, YARN, kubernetes and the standalone scheduler.
     * An internal RPC interface is used for communication with the driver,
     * except in the case of Mesos fine-grained mode.
     *
     * @param executorId               id of this executor
     * @param executorHostname         host name of the machine this executor runs on
     * @param env                      the SparkEnv of this executor (metrics system, block manager, ...)
     * @param userClassPath            extra user class path URLs; empty by default
     * @param isLocal                  whether running in local mode; false by default
     * @param uncaughtExceptionHandler handler for uncaught exceptions; defaults to a
     *                                 SparkUncaughtExceptionHandler
     * @param resources                resources (e.g. GPUs) assigned to this executor
     */
    private[spark] class Executor(
        executorId: String,
        executorHostname: String,
        env: SparkEnv,
        userClassPath: Seq[URL] = Nil,
        isLocal: Boolean = false,
        uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler,
        resources: immutable.Map[String, ResourceInformation])
      extends Logging
    
    
    // Running tasks keyed by task id; launchTask inserts each new TaskRunner here.
      private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
    
    // Thread pool that runs the TaskRunners submitted by launchTask: a cached pool (threads are
    // created on demand and reused when idle) of daemon threads named
    // "Executor task launch worker-N".
    private val threadPool = {
        val threadFactory = new ThreadFactoryBuilder()
          .setDaemon(true)
          .setNameFormat("Executor task launch worker-%d")
          // Threads are created as UninterruptibleThreads; the placeholder name "unused" is
          // presumably overridden by the name format above.
          .setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
          .build()
        Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
      }
    
        /**
         * The task to run. This will be set in run() by deserializing the task binary coming
         * from the driver. Once it is set, it will never be changed.
         *
         * NOTE(review): this excerpt shows the field at Executor level, but TaskRunner.run()
         * assigns it; it is presumably a TaskRunner member in the full source — confirm.
         */
     @volatile var task: Task[Any] = _
    
    
    调用方法
    
    /**
     * Starts running `taskDescription` on this executor: wraps it in a TaskRunner, records the
     * runner in runningTasks under the task id, and submits it to the task thread pool.
     * If the executor is decommissioned an error is logged after the task has been submitted.
     */
    def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
      val runner = new TaskRunner(context, taskDescription, plugins)
      val id = taskDescription.taskId
      runningTasks.put(id, runner)
      threadPool.execute(runner)
      if (decommissioned) log.error(s"Launching a task while in decommissioned state.")
    }
    
    // Runnable that executes one task on an executor thread and reports its progress through
    // execBackend.statusUpdate: RUNNING first, then FINISHED together with the serialized result.
    class TaskRunner(
          execBackend: ExecutorBackend,
          private val taskDescription: TaskDescription,
          private val plugins: Option[PluginContainer])
        extends Runnable {
          
          override def run(): Unit = {
            ....
             // Report that the task has started running.
          execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
            ....
            // Deserialize the Task object carried by the TaskDescription.
            task = ser.deserialize[Task[Any]](
              taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
            task.localProperties = taskDescription.properties
            task.setTaskMemoryManager(taskMemoryManager)
            
            // Run the task (Task.run) and obtain its result value.
            val value = Utils.tryWithSafeFinally {
              val res = task.run(
                taskAttemptId = taskId,
                attemptNumber = taskDescription.attemptNumber,
                metricsSystem = env.metricsSystem,
                cpus = taskDescription.cpus,
                resources = taskDescription.resources,
                plugins = plugins)
              
              .....
            // Serialize the result value.
            val valueBytes = resultSer.serialize(value)
           // directSend = sending directly back to the driver
            // Pick how the result travels back, depending on its size:
            val serializedResult: ByteBuffer = {
              if (maxResultSize > 0 && resultSize > maxResultSize) {
                // Larger than maxResultSize: the result is dropped. Only a block id and the size
                // are reported; nothing is stored in the BlockManager in this branch.
                logWarning(s"Finished $taskName. Result is larger than maxResultSize " +
                  s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
                  s"dropping it.")
                ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
              } else if (resultSize > maxDirectResultSize) {
                // Too large to send directly: store it in the local BlockManager.
                val blockId = TaskResultBlockId(taskId)
                env.blockManager.putBytes(
                  blockId,
                  new ChunkedByteBuffer(serializedDirectResult.duplicate()),
                  StorageLevel.MEMORY_AND_DISK_SER)
                // NOTE(review): this log message has an unbalanced ")" after "BlockManager".
                logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
                // Only the block id (plus size) is sent to the driver.
                ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
              } else {
                // Small enough to be sent to the driver directly in the status update.
                logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
                serializedDirectResult
              }
            }
              
           // Report FINISHED together with the (direct or indirect) serialized result.
            execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
            }         
    

    Task的执行逻辑在Task类的run方法中

    /**
     * A unit of execution. We have two kinds of Task's in Spark:
     *
     *  - [[org.apache.spark.scheduler.ShuffleMapTask]]
     *  - [[org.apache.spark.scheduler.ResultTask]]
     *
     * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
     * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
     * and sends the task output back to the driver application. A ShuffleMapTask executes the task
     * and divides the task output to multiple buckets (based on the task's partitioner).
     *
     * @param stageId id of the stage this task belongs to
     * @param stageAttemptId attempt id of the stage this task belongs to
     * @param partitionId index of the number in the RDD
     * @param localProperties copy of thread-local properties set by the user on the driver side.
     * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
     *                              and sent to executor side.
     *
     * The parameters below are optional:
     * @param jobId id of the job this task belongs to
     * @param appId id of the app this task belongs to
     * @param appAttemptId attempt id of the app this task belongs to
     * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks
     *                  at the same time for a barrier stage.
     */
    private[spark] abstract class Task[T](
        val stageId: Int,
        val stageAttemptId: Int,
        val partitionId: Int,
        @transient var localProperties: Properties = new Properties,
        // The default value is only used in tests.
        serializedTaskMetrics: Array[Byte] =
          SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
        val jobId: Option[Int] = None,
        val appId: Option[String] = None,
        val appAttemptId: Option[String] = None,
        val isBarrier: Boolean = false) extends Serializable {
      
      // Task context, to be initialized in run().
      @transient var context: TaskContext = _
    
    
    
    
    /**
       * Called by [[org.apache.spark.executor.Executor]] to run this task.
       *
       * Builds a TaskContextImpl for this attempt (wrapped in a BarrierTaskContext for barrier
       * tasks), stores it in `context`, and then delegates the actual work to runTask.
       *
       * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
       * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
       * @param metricsSystem metrics system handed to the task context
       * @param cpus number of CPUs assigned to this task attempt
       * @param resources other host resources (like gpus) that this task attempt can access
       * @param plugins optional plugin container
       * @return the result of the task along with updates of Accumulators.
       */
      final def run(
          taskAttemptId: Long,
          attemptNumber: Int,
          metricsSystem: MetricsSystem,
          cpus: Int,
          resources: Map[String, ResourceInformation],
          plugins: Option[PluginContainer]): T = {
        
        val taskContext = new TaskContextImpl(
          stageId,
          stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
          partitionId,
          taskAttemptId,
          attemptNumber,
          taskMemoryManager,
          localProperties,
          metricsSystem,
          metrics,
          cpus,
          resources)
    
        context = if (isBarrier) {
          new BarrierTaskContext(taskContext)
        } else {
          taskContext
        }
        
        ....
        // runTask is implemented differently by each concrete subclass (e.g. ShuffleMapTask.runTask).
        runTask(context)
        
        }
    }
    

    先来看ShuffleMapTask的实现

    /**
     * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
     * specified in the ShuffleDependency).
     *
     * See [[org.apache.spark.scheduler.Task]] for more information.
     *
     * @param stageId id of the stage this task belongs to
     * @param stageAttemptId attempt id of the stage this task belongs to
     * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
     *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
     * @param partition partition of the RDD this task is associated with
     * @param locs preferred task execution locations for locality scheduling
     * @param localProperties copy of thread-local properties set by the user on the driver side.
     * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
     *                              and sent to executor side.
     *
     * The parameters below are optional:
     * @param jobId id of the job this task belongs to
     * @param appId id of the app this task belongs to
     * @param appAttemptId attempt id of the app this task belongs to
     * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks
     *                  at the same time for a barrier stage.
     */
    private[spark] class ShuffleMapTask(
        stageId: Int,
        stageAttemptId: Int,
        taskBinary: Broadcast[Array[Byte]],
        partition: Partition,
        @transient private var locs: Seq[TaskLocation],
        localProperties: Properties,
        serializedTaskMetrics: Array[Byte],
        jobId: Option[Int] = None,
        appId: Option[String] = None,
        appAttemptId: Option[String] = None,
        isBarrier: Boolean = false)
      extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
        serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
      with Logging {
        ....
        
        // Computes this task's RDD partition and writes it as shuffle output; the returned
        // MapStatus describes what was written.
        override def runTask(context: TaskContext): MapStatus = {
        // Deserialize the RDD using the broadcast variable.
        
        val ser = SparkEnv.get.closureSerializer.newInstance()
        val rddAndDep = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        // Record deserialization wall time and, when supported, CPU time (the start values are
        // presumably captured in the elided code above).
        _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
        _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
        } else 0L
    
        val rdd = rddAndDep._1
        val dep = rddAndDep._2
        // While we use the old shuffle fetch protocol, we use partitionId as mapId in the
        // ShuffleBlockId construction.
        val mapId = if (SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
          partitionId
        } else context.taskAttemptId()
        // Delegate computing the partition and writing shuffle output to the dependency's
        // ShuffleWriteProcessor.
        dep.shuffleWriterProcessor.write(rdd, dep, mapId, context, partition)
      }
    

    写数据

    /**
     * The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor
     * and put it into [[ShuffleDependency]], and executors use it in each ShuffleMapTask.
     */
    private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
      
      
      /**
       * The write process for particular partition, it controls the life circle of [[ShuffleWriter]]
       * get from [[ShuffleManager]] and triggers rdd compute, finally return the [[MapStatus]] for
       * this task.
       *
       * If anything fails after the writer has been obtained, the writer is stopped with
       * `success = false` so it gets a chance to clean up, and the original exception is rethrown.
       *
       * @param rdd       the RDD whose `partition` is computed and written
       * @param dep       the shuffle dependency providing the shuffle handle and push settings
       * @param mapId     the map id used to identify this task's shuffle output
       * @param context   the running task's context
       * @param partition the partition to compute
       * @return the MapStatus returned by `writer.stop(success = true)`
       */
      def write(
          rdd: RDD[_],
          dep: ShuffleDependency[_, _, _],
          mapId: Long,
          context: TaskContext,
          partition: Partition): MapStatus = {
        var writer: ShuffleWriter[Any, Any] = null
        try {
          val manager = SparkEnv.get.shuffleManager
          writer = manager.getWriter[Any, Any](
            dep.shuffleHandle,
            mapId,
            context,
            createMetricsReporter(context))
          // Compute this partition of the RDD and feed every record to the shuffle writer.
          writer.write(
            rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
          // Stopping with success = true yields the MapStatus (metadata about the written
          // output) wrapped in an Option.
          val mapStatus = writer.stop(success = true)
          if (mapStatus.isDefined) {
            // Initiate shuffle push process if push based shuffle is enabled
            // The map task only takes care of converting the shuffle data file into multiple
            // block push requests. It delegates pushing the blocks to a different thread-pool -
            // ShuffleBlockPusher.BLOCK_PUSHER_POOL.
            if (dep.shuffleMergeEnabled && dep.getMergerLocs.nonEmpty && !dep.shuffleMergeFinalized) {
              manager.shuffleBlockResolver match {
                case resolver: IndexShuffleBlockResolver =>
                  val dataFile = resolver.getDataFile(dep.shuffleId, mapId)
                  new ShuffleBlockPusher(SparkEnv.get.conf)
                    .initiateBlockPush(dataFile, writer.getPartitionLengths(), dep, partition.index)
                case _ =>
                  // Only IndexShuffleBlockResolver data files are pushed; other resolvers are skipped.
              }
            }
          }
          // Return the metadata describing this task's shuffle output.
          mapStatus.get
        } catch {
          case e: Exception =>
            // Give the writer a chance to clean up after the failure. A second failure while
            // stopping must not mask the original error, so it is only logged.
            try {
              if (writer != null) {
                writer.stop(success = false)
              }
            } catch {
              case stopError: Exception =>
                logDebug("Could not stop writer", stopError)
            }
            throw e
        }
      }
    

    ShuffleMapTask 通过 ShuffleWriter 将计算结果写入 Executor 本地的 shuffle 数据文件(由 BlockManager 管理),最终返回包含输出位置和各分区大小等元数据的 MapStatus。下游 Stage 读取输入数据时,正是依据 MapStatus 来定位这些 shuffle 数据。

    再看看ResultTask的runTask的实现:

    override def runTask(context: TaskContext): U = {
      // Decode the (RDD, user function) pair shipped in the broadcast task binary. The time
      // spent decoding is recorded as wall-clock nanoseconds and, when the current JVM can
      // measure it, as CPU time of this thread (0 otherwise).
      val mxBean = ManagementFactory.getThreadMXBean
      val wallStartNs = System.nanoTime()
      val cpuStart =
        if (mxBean.isCurrentThreadCpuTimeSupported) mxBean.getCurrentThreadCpuTime else 0L

      val closureSer = SparkEnv.get.closureSerializer.newInstance()
      val taskBytes = ByteBuffer.wrap(taskBinary.value)
      val (taskRdd, userFunc) = closureSer.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
        taskBytes, Thread.currentThread.getContextClassLoader)

      _executorDeserializeTimeNs = System.nanoTime() - wallStartNs
      _executorDeserializeCpuTime =
        if (mxBean.isCurrentThreadCpuTimeSupported) mxBean.getCurrentThreadCpuTime - cpuStart
        else 0L

      // Apply the user's function to this task's partition; its value is the task result.
      userFunc(context, taskRdd.iterator(partition, context))
    }
    

    任务执行完成后,Executor 会调用 CoarseGrainedExecutorBackend.statusUpdate()。该方法向 Driver 端的 RPC 端点发送 StatusUpdate 消息,汇报任务的执行状态及结果数据。

    Driver 端由 CoarseGrainedSchedulerBackend 的 DriverEndpoint 接收并处理该消息:

     // Driver-side RPC message handler (excerpt): only the StatusUpdate case is shown; the other
     // cases and the closing brace of the partial function are omitted from this article.
     override def receive: PartialFunction[Any, Unit] = {
          case StatusUpdate(executorId, taskId, state, data, resources) =>
            // Forward the task's new state and serialized payload to the task scheduler.
            scheduler.statusUpdate(taskId, state, data.value)
            // Once the task reaches a terminal state, give its cores and any named resources
            // back to the executor that ran it, then call makeOffers for that executor so the
            // freed capacity can be used by other tasks.
            if (TaskState.isFinished(state)) {
              executorDataMap.get(executorId) match {
                case Some(executorInfo) =>
                  val rpId = executorInfo.resourceProfileId
                  val prof = scheduler.sc.resourceProfileManager.resourceProfileFromId(rpId)
                  // Cores occupied by one task under this executor's resource profile.
                  val taskCpus = ResourceProfile.getTaskCpusOrDefaultForProfile(prof, conf)
                  executorInfo.freeCores += taskCpus
                  // Release the resource addresses (per resource name) the task was assigned.
                  resources.foreach { case (k, v) =>
                    executorInfo.resourcesInfo.get(k).foreach { r =>
                      r.release(v.addresses)
                    }
                  }
                  makeOffers(executorId)
                case None =>
                  // Ignoring the update since we don't know about the executor.
                  logWarning(s"Ignored task status update ($taskId state $state) " +
                    s"from unknown executor with ID $executorId")
              }
            }
    

    其中 scheduler.statusUpdate(...) 调用的是 TaskSchedulerImpl 的任务状态更新方法:

    // Task scheduler's handler for a task state change (excerpt): the part of the body marked
    // `.....` and several closing braces are omitted from this article.
    def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer): Unit = {
        var failedExecutor: Option[String] = None
        var reason: Option[ExecutorLossReason] = None
        synchronized {
          try {
            // Look up the TaskSetManager that owns task `tid`, if any.
            Option(taskIdToTaskSetManager.get(tid)) match {
              case Some(taskSet) =>
                .....
                // For a terminal state: clean up scheduler state for the task, remove it from
                // the task set's running tasks, and hand the serialized payload to
                // taskResultGetter on the success or failure path.
                if (TaskState.isFinished(state)) {
                  cleanupTaskState(tid)
                  taskSet.removeRunningTask(tid)
                  if (state == TaskState.FINISHED) {
                    // Task succeeded: decode its result.
                    taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
                  } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
                    // Task failed, was killed, or was lost: process the failure.
                    taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
                  }
                }
             
          } 
      }
    
    // Decodes a successful task's result on the getTaskResultExecutor thread pool (excerpt:
    // the method continues past the result-decoding match shown here).
    // A result that would exceed the driver's maxResultSize budget causes the task to be
    // reported as KILLED instead.
    def enqueueSuccessfulTask(
          taskSetManager: TaskSetManager,
          tid: Long,
          serializedData: ByteBuffer): Unit = {
        getTaskResultExecutor.execute(new Runnable {
          override def run(): Unit = Utils.logUncaughtExceptions {
            try {
              // Yields the decoded result together with its size in bytes.
              val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
                // Direct result: the result bytes arrived inline in serializedData.
                case directResult: DirectTaskResult[_] =>
                  if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
                    // kill the task so that it will not become zombie task
                    scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
                      "Tasks result size has exceeded maxResultSize"))
                    return
                  }
                  // deserialize "value" without holding any lock so that it won't block other threads.
                  // We should call it here, so that when it's called again in
                  // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
                  directResult.value(taskResultSerializer.get())
                  (directResult, serializedData.limit())
                  // Indirect result: the payload only carries a blockId and a size; the actual
                  // result must be fetched from a remote BlockManager.
                case IndirectTaskResult(blockId, size) =>
                  if (!taskSetManager.canFetchMoreResults(size)) {
                    // dropped by executor if size is larger than maxResultSize
                    sparkEnv.blockManager.master.removeBlock(blockId)
                    // kill the task so that it will not become zombie task
                    scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
                      "Tasks result size has exceeded maxResultSize"))
                    return
                  }
                  logDebug(s"Fetching indirect task result for ${taskSetManager.taskName(tid)}")
                  // Mark the task as "getting result" before starting the remote fetch.
                  scheduler.handleTaskGettingResult(taskSetManager, tid)
                  // Fetch the serialized result bytes for blockId from a remote BlockManager.
                  val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
                  if (serializedTaskResult.isEmpty) {
                    /* We won't be able to get the task result if the machine that ran the task failed
                     * between when the task ended and when we tried to fetch the result, or if the
                     * block manager had to flush the result. */
                    scheduler.handleFailedTask(
                      taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
                    return
                  }
                  val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
                    serializedTaskResult.get.toByteBuffer)
                  // force deserialization of referenced value
                  deserializedResult.value(taskResultSerializer.get())
                  // The result now lives on the driver, so the remote block is no longer needed.
                  sparkEnv.blockManager.master.removeBlock(blockId)
                  (deserializedResult, size)
              }
    

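    当结果以 IndirectTaskResult 形式返回时,需要从远程 BlockManager 拉取结果。拉取之前,TaskSchedulerImpl 会转调 TaskSetManager.handleTaskGettingResult,将任务标记为“正在获取结果”,并通知 DAGScheduler: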

    /**
     * Records that the result of task `tid` is now being fetched: stamps the task's info record
     * with the current clock time and then informs the DAG scheduler about it.
     *
     * @param tid id of the task whose result is being fetched
     */
    def handleTaskGettingResult(tid: Long): Unit = {
      val taskInfo = taskInfos(tid)
      val startedAt = clock.getTimeMillis()
      taskInfo.markGettingResult(startedAt)
      sched.dagScheduler.taskGettingResult(taskInfo)
    }
    

    参考文章
    https://blog.csdn.net/yxf19034516/article/details/112617702
    https://blog.csdn.net/weixin_43878293/article/details/101027788

    相关文章

      网友评论

          本文标题:Spark 提交执行源码学习

          本文链接:https://www.haomeiwen.com/subject/axwlfrtx.html