
Spark Source Code: Starting the Master

Author: Jorvi | Published 2019-12-17 15:36

    Source Code Series Index


    1 start-master.sh

    -- spark/sbin/start-master.sh
    
    CLASS="org.apache.spark.deploy.master.Master"
    
    "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \
      --host $SPARK_MASTER_HOST --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \
      $ORIGINAL_ARGS
    
    

    2 Calling the main Function

    • Enter org.apache.spark.deploy.master.Master.scala
      def main(argStrings: Array[String]) {
        Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler(
          exitOnUncaughtException = false))
        Utils.initDaemon(log)
        val conf = new SparkConf
        val args = new MasterArguments(argStrings, conf)
        val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
        rpcEnv.awaitTermination()
      }
    
    
    1. val conf = new SparkConf
      Configuration is held in a ConcurrentHashMap[String, String]; the constructor copies every system property whose key starts with "spark." into that map (see the sketch after this list).

    2. val args = new MasterArguments(argStrings, conf)
      Parses the command-line arguments and loads default properties to produce the Master's settings.

    3. val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
      Creates the RpcEnv and registers the RpcEndpoint (the key part).

    4. rpcEnv.awaitTermination()
      Blocks until the RpcEnv is shut down.
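
    A quick, hedged illustration of point 1 (not taken from the Master source; the property name is made up purely to show the behaviour): SparkConf's no-argument constructor loads defaults, i.e. every JVM system property whose key starts with "spark.".

      import org.apache.spark.SparkConf

      // Hypothetical property, set before the SparkConf is built.
      sys.props("spark.master.example.flag") = "true"

      val conf = new SparkConf()        // loadDefaults = true: copies spark.* system properties
      assert(conf.get("spark.master.example.flag") == "true")
      assert(conf.getOption("example.non.spark.key").isEmpty)   // non-spark.* keys are ignored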


    3 A Closer Look at startRpcEnvAndEndpoint

    • Enter org.apache.spark.deploy.master.Master.scala
      /**
       * Start the Master and return a three tuple of:
       *   (1) The Master RpcEnv
       *   (2) The web UI bound port
       *   (3) The REST server bound port, if any
       */
      def startRpcEnvAndEndpoint(
          host: String,
          port: Int,
          webUiPort: Int,
          conf: SparkConf): (RpcEnv, Int, Option[Int]) = {
        val securityMgr = new SecurityManager(conf)
        val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
        val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
          new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
        val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest)
        (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
      }
    
    1. val securityMgr = new SecurityManager(conf)
      Creates a SecurityManager, which manages accounts, ACLs, and authentication.

    2. val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
      Creates the RpcEnv.

    3. val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
      Creates the RpcEndpoint (Master), registers it with the RpcEnv, and gets back its RpcEndpointRef.

    4. val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest)
      The RpcEndpointRef (masterEndpoint) synchronously sends a BoundPortsRequest message to the corresponding RpcEndpoint's Master.receiveAndReply and waits, up to a timeout, for the reply (see the sketch after this list).

    5. (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
      Returns the tuple (Master RpcEnv, web UI port, REST server port if any).
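
    For point 4, the request is answered inside Master.receiveAndReply. Roughly (an abbreviated sketch showing only the clause relevant here; the member names are Master's own):

      // Abbreviated from Master.receiveAndReply; all other cases omitted.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case BoundPortsRequest =>
          context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))
      }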


    3.1 How the RpcEnv Is Created

    • Enter org.apache.spark.rpc.RpcEnv.scala
    private[spark] object RpcEnv {
    
      def create(
          name: String,
          host: String,
          port: Int,
          conf: SparkConf,
          securityManager: SecurityManager,
          clientMode: Boolean = false): RpcEnv = {
        create(name, host, host, port, conf, securityManager, 0, clientMode)
      }
    
      def create(
          name: String,
          bindAddress: String,
          advertiseAddress: String,
          port: Int,
          conf: SparkConf,
          securityManager: SecurityManager,
          numUsableCores: Int,
          clientMode: Boolean): RpcEnv = {
        val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
          numUsableCores, clientMode)
        new NettyRpcEnvFactory().create(config)
      }
    }
    

    An RpcEnvConfig is built, and the NettyRpcEnvFactory factory then creates a NettyRpcEnv from it.

    • Enter org.apache.spark.rpc.netty.NettyRpcEnvFactory.scala
    private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
    
      def create(config: RpcEnvConfig): RpcEnv = {
        val sparkConf = config.conf
        // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
        // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
        val javaSerializerInstance =
          new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
        val nettyEnv =
          new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
            config.securityManager, config.numUsableCores)
        if (!config.clientMode) {
          val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
            nettyEnv.startServer(config.bindAddress, actualPort)
            (nettyEnv, nettyEnv.address.port)
          }
          try {
            Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
          } catch {
            case NonFatal(e) =>
              nettyEnv.shutdown()
              throw e
          }
        }
        nettyEnv
      }
    }
    
    1. Create the NettyRpcEnv object;
    2. Read clientMode from the config. If clientMode is false, the RpcEnv is being created on the server side, so Utils.startServiceOnPort() is called to start the service, which in turn invokes the function startNettyRpcEnv: Int => (NettyRpcEnv, Int);
    3. startNettyRpcEnv calls NettyRpcEnv.startServer(), which creates a TransportServer;
    4. Return the NettyRpcEnv.

    3.1.1 Creating the NettyRpcEnv Object

    • Enter org.apache.spark.rpc.netty.NettyRpcEnv.scala
    private[netty] class NettyRpcEnv(
        val conf: SparkConf,
        javaSerializerInstance: JavaSerializerInstance,
        host: String,
        securityManager: SecurityManager,
        numUsableCores: Int) extends RpcEnv(conf) with Logging {
    
      private[netty] val transportConf = SparkTransportConf.fromSparkConf(
        conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
        "rpc",
        conf.getInt("spark.rpc.io.threads", numUsableCores))
    
      private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)
    
      private val streamManager = new NettyStreamManager(this)
    
      private val transportContext = new TransportContext(transportConf,
        new NettyRpcHandler(dispatcher, this, streamManager))
    
      // omitted
    }
    
    1. Constructing a NettyRpcEnv also constructs, internally, a Dispatcher, a NettyStreamManager, a TransportContext, and so on;
    2. Constructing the TransportContext also creates a NettyRpcHandler, which dispatches incoming RPC requests to the registered endpoints (see the sketch below).
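
    The hand-off in point 2 is small. Roughly (an abbreviated view of NettyRpcHandler.receive; connection tracking and error handling are omitted), every incoming network message is decoded into a RequestMessage and posted to the Dispatcher:

      // Abbreviated sketch of NettyRpcHandler.receive:
      override def receive(
          client: TransportClient,
          message: ByteBuffer,
          callback: RpcResponseCallback): Unit = {
        val messageToDispatch = internalReceive(client, message)   // bytes -> RequestMessage
        dispatcher.postRemoteMessage(messageToDispatch, callback)  // ends up in the target endpoint's Inbox
      }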

    3.1.2 Utils.startServiceOnPort()

    • Enter org.apache.spark.util.Utils.scala
      /**
       * Attempt to start a service on the given port, or fail after a number of attempts.
       * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
       *
       * @param startPort The initial port to start the service on.
       * @param startService Function to start service on a given port.
       *                     This is expected to throw java.net.BindException on port collision.
       * @param conf A SparkConf used to get the maximum number of retries when binding to a port.
       * @param serviceName Name of the service.
       * @return (service: T, port: Int)
       */
      def startServiceOnPort[T](
          startPort: Int,
          startService: Int => (T, Int),
          conf: SparkConf,
          serviceName: String = ""): (T, Int) = {
    
        require(startPort == 0 || (1024 <= startPort && startPort < 65536),
          "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.")
    
        val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
        val maxRetries = portMaxRetries(conf)
        for (offset <- 0 to maxRetries) {
          // Do not increment port if startPort is 0, which is treated as a special port
          val tryPort = if (startPort == 0) {
            startPort
          } else {
            userPort(startPort, offset)
          }
          try {
            val (service, port) = startService(tryPort)
            logInfo(s"Successfully started service$serviceString on port $port.")
            return (service, port)
          } catch {
            case e: Exception if isBindCollision(e) =>
              if (offset >= maxRetries) {
                val exceptionMessage = if (startPort == 0) {
                  s"${e.getMessage}: Service$serviceString failed after " +
                    s"$maxRetries retries (on a random free port)! " +
                    s"Consider explicitly setting the appropriate binding address for " +
                    s"the service$serviceString (for example spark.driver.bindAddress " +
                    s"for SparkDriver) to the correct binding address."
                } else {
                  s"${e.getMessage}: Service$serviceString failed after " +
                    s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " +
                    s"the appropriate port for the service$serviceString (for example spark.ui.port " +
                    s"for SparkUI) to an available port or increasing spark.port.maxRetries."
                }
                val exception = new BindException(exceptionMessage)
                // restore original stack trace
                exception.setStackTrace(e.getStackTrace)
                throw exception
              }
              if (startPort == 0) {
                // As startPort 0 is for a random free port, it is most possibly binding address is
                // not correct.
                logWarning(s"Service$serviceString could not bind on a random free port. " +
                  "You may check whether configuring an appropriate binding address.")
              } else {
                logWarning(s"Service$serviceString could not bind on port $tryPort. " +
                  s"Attempting port ${tryPort + 1}.")
              }
          }
        }
        // Should never happen
        throw new SparkException(s"Failed to start service$serviceString on port $startPort")
      }
    

    Arguments:
    (1) startPort: the starting port number, taken from the RpcEnvConfig built from the SparkConf
    (2) startService: the function defined earlier, val startNettyRpcEnv: Int => (NettyRpcEnv, Int)
    (3) conf: the SparkConf
    (4) serviceName: the service name ("sparkMaster")

    Logic:
    (1) Validate startPort
    (2) Try up to maxRetries + 1 times (offsets 0 to maxRetries); each attempt computes a tryPort (via its own increment rule, sketched below) and passes it to startService to start the service
    (3) If the service starts successfully, return
    (4) If it still has not started after maxRetries retries, throw an exception

    Inside startService, nettyEnv.startServer(config.bindAddress, actualPort) is called to start the server.
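
    The increment rule used in step (2) is Utils.userPort. Roughly, it offsets the base port and wraps around within the non-privileged range:

      // Roughly what Utils.userPort does: add the offset, wrapping within [1024, 65535].
      def userPort(base: Int, offset: Int): Int = {
        (base + offset - 1024) % (65536 - 1024) + 1024
      }

      userPort(7077, 1)   // 7078
      userPort(65535, 1)  // wraps back to 1024 instead of overflowing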

    3.1.3 NettyRpcEnv.startServer()

    • Enter org.apache.spark.rpc.netty.NettyRpcEnv.scala
      def startServer(bindAddress: String, port: Int): Unit = {
        val bootstraps: java.util.List[TransportServerBootstrap] =
          if (securityManager.isAuthenticationEnabled()) {
            java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager))
          } else {
            java.util.Collections.emptyList()
          }
        server = transportContext.createServer(bindAddress, port, bootstraps)
        dispatcher.registerRpcEndpoint(
          RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
      }
    
    
    1. The TransportContext creates a TransportServer;
    2. An RpcEndpointVerifier is registered with the Dispatcher (the registration flow is the same for every RpcEndpoint; see below).
    • Enter org.apache.spark.network.TransportContext.java
      /** Create a server which will attempt to bind to a specific host and port. */
      public TransportServer createServer(
          String host, int port, List<TransportServerBootstrap> bootstraps) {
        return new TransportServer(this, host, port, rpcHandler, bootstraps);
      }
    
    • Enter org.apache.spark.network.server.TransportServer.java
      /**
       * Creates a TransportServer that binds to the given host and the given port, or to any available
       * if 0. If you don't want to bind to any special host, set "hostToBind" to null.
       * */
      public TransportServer(
          TransportContext context,
          String hostToBind,
          int portToBind,
          RpcHandler appRpcHandler,
          List<TransportServerBootstrap> bootstraps) {
        this.context = context;
        this.conf = context.getConf();
        this.appRpcHandler = appRpcHandler;
        this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
    
        boolean shouldClose = true;
        try {
          init(hostToBind, portToBind);
          shouldClose = false;
        } finally {
          if (shouldClose) {
            JavaUtils.closeQuietly(this);
          }
        }
      }
    
    
      private void init(String hostToBind, int portToBind) {
    
        IOMode ioMode = IOMode.valueOf(conf.ioMode());
        EventLoopGroup bossGroup =
          NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server");
        EventLoopGroup workerGroup = bossGroup;
    
        PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
          conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
    
        bootstrap = new ServerBootstrap()
          .group(bossGroup, workerGroup)
          .channel(NettyUtils.getServerChannelClass(ioMode))
          .option(ChannelOption.ALLOCATOR, allocator)
          .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)
          .childOption(ChannelOption.ALLOCATOR, allocator);
    
        this.metrics = new NettyMemoryMetrics(
          allocator, conf.getModuleName() + "-server", conf);
    
        if (conf.backLog() > 0) {
          bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
        }
    
        if (conf.receiveBuf() > 0) {
          bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
        }
    
        if (conf.sendBuf() > 0) {
          bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
        }
    
        bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) {
            logger.debug("New connection accepted for remote address {}.", ch.remoteAddress());
    
            RpcHandler rpcHandler = appRpcHandler;
            for (TransportServerBootstrap bootstrap : bootstraps) {
              rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
            }
            context.initializePipeline(ch, rpcHandler);
          }
        });
    
        InetSocketAddress address = hostToBind == null ?
            new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind);
        channelFuture = bootstrap.bind(address);
        channelFuture.syncUninterruptibly();
    
        port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
        logger.debug("Shuffle server started on port: {}", port);
      }
    

    During TransportServer initialization:

    1. A ServerBootstrap is set up via the Netty API, with a ChannelInitializer installed as the child ChannelHandler of the ServerBootstrap;
    2. An InetSocketAddress is created and the ServerBootstrap binds to it.

    Pipeline initialization:

    • Enter org.apache.spark.network.TransportContext.java
      /**
       * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
       * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
       * response messages.
       *
       * @param channel The channel to initialize.
       * @param channelRpcHandler The RPC handler to use for the channel.
       *
       * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
       * be used to communicate on this channel. The TransportClient is directly associated with a
       * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
       */
      public TransportChannelHandler initializePipeline(
          SocketChannel channel,
          RpcHandler channelRpcHandler) {
        try {
          TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
          channel.pipeline()
            .addLast("encoder", ENCODER)
            .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
            .addLast("decoder", DECODER)
            .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
            // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
            // would require more logic to guarantee if this were not part of the same event loop.
            .addLast("handler", channelHandler);
          return channelHandler;
        } catch (RuntimeException e) {
          logger.error("Error while initializing Netty pipeline", e);
          throw e;
        }
      }
    
    1. Create a TransportChannelHandler;
    2. Add the TransportChannelHandler (together with the encoder, frame decoder, decoder, and idle-state handler) to the SocketChannel's pipeline.

    3.2 Creating and Registering the Master (an RpcEndpoint)

    • Enter org.apache.spark.deploy.master.Master.scala
    private[deploy] class Master(
        override val rpcEnv: RpcEnv,
        address: RpcAddress,
        webUiPort: Int,
        val securityMgr: SecurityManager,
        val conf: SparkConf)
      extends ThreadSafeRpcEndpoint with Logging with LeaderElectable {
       
        // omitted
    
    }
    

    Master extends ThreadSafeRpcEndpoint, so it is an RpcEndpoint.

    • Enter org.apache.spark.rpc.netty.NettyRpcEnv.scala
      override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
        dispatcher.registerRpcEndpoint(name, endpoint)
      }
    
    • Enter org.apache.spark.rpc.netty.Dispatcher.scala
      def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
        val addr = RpcEndpointAddress(nettyEnv.address, name)
        val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
        synchronized {
          if (stopped) {
            throw new IllegalStateException("RpcEnv has been stopped")
          }
          if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
            throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
          }
          val data = endpoints.get(name)
          endpointRefs.put(data.endpoint, data.ref)
          receivers.offer(data)  // for the OnStart message
        }
        endpointRef
      }
    
    1. Create the corresponding NettyRpcEndpointRef. Its constructor takes three arguments: the SparkConf, an RpcEndpointAddress, and the NettyRpcEnv; here the RpcEndpointAddress is built from the NettyRpcEnv's address plus the endpoint name;
    2. Build an EndpointData(name, endpoint, endpointRef) and store it in endpoints (a ConcurrentMap[String, EndpointData]);
    3. Also offer the EndpointData to receivers (a LinkedBlockingQueue[EndpointData]). The EndpointData has now been queued; where is it taken off the queue and processed? See "Message Processing in the Dispatcher" below;
    4. Return the NettyRpcEndpointRef.

    A look at EndpointData:

    • Enter org.apache.spark.rpc.netty.Dispatcher.scala (inner class EndpointData)
      private class EndpointData(
          val name: String,
          val endpoint: RpcEndpoint,
          val ref: NettyRpcEndpointRef) {
        val inbox = new Inbox(ref, endpoint)
      }
    

    Constructing an EndpointData creates an Inbox. A look at Inbox:

    • Enter org.apache.spark.rpc.netty.Inbox.scala
    private[netty] class Inbox(
        val endpointRef: NettyRpcEndpointRef,
        val endpoint: RpcEndpoint)
      extends Logging {
    
      @GuardedBy("this")
      protected val messages = new java.util.LinkedList[InboxMessage]()
    
      // OnStart should be the first message to process
      inbox.synchronized {
        messages.add(OnStart)
      }
    }
    

    Constructing an Inbox declares a LinkedList[InboxMessage] (messages) and immediately adds OnStart to it, so OnStart is enqueued the moment the Inbox is created and is guaranteed to be the first message processed.

    OnStart has been queued; where is it taken out and processed? See "Message Processing in the Dispatcher" below.
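
    To make the lifecycle concrete, here is a minimal, hypothetical endpoint (not part of Spark; it only assumes the private[spark] RpcEndpoint API shown in this article, so it would have to live under the org.apache.spark package). Registering it enqueues OnStart into its Inbox, so onStart() runs before any user message is handled:

      import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

      // Hypothetical endpoint, used only to illustrate the OnStart-first guarantee.
      class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
        override def onStart(): Unit = {
          // Runs first, triggered by the OnStart message enqueued when the Inbox was created.
          println("EchoEndpoint started")
        }
        override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
          case msg => context.reply(msg)   // echo back whatever askSync sends
        }
      }

      // Registration returns an RpcEndpointRef, exactly as setupEndpoint does for Master:
      // val echoRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
      // echoRef.askSync[String]("ping")   // handled by receiveAndReply, after onStart has run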


    3.3 Sending the Message Synchronously

    • Enter org.apache.spark.rpc.RpcEndpointRef.scala
      /**
       * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
       * default timeout, throw an exception if this fails.
       *
       * Note: this is a blocking action which may cost a lot of time,  so don't call it in a message
       * loop of [[RpcEndpoint]].
    
       * @param message the message to send
       * @tparam T type of the reply message
       * @return the reply message from the corresponding [[RpcEndpoint]]
       */
      def askSync[T: ClassTag](message: Any): T = askSync(message, defaultAskTimeout)
    
    
      /**
       * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
       * specified timeout, throw an exception if this fails.
       *
       * Note: this is a blocking action which may cost a lot of time, so don't call it in a message
       * loop of [[RpcEndpoint]].
       *
       * @param message the message to send
       * @param timeout the timeout duration
       * @tparam T type of the reply message
       * @return the reply message from the corresponding [[RpcEndpoint]]
       */
      def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
        val future = ask[T](message, timeout)
        timeout.awaitResult(future)
      }
    
    
    • Enter org.apache.spark.rpc.netty.NettyRpcEndpointRef.scala
    private[netty] class NettyRpcEndpointRef(
        @transient private val conf: SparkConf,
        private val endpointAddress: RpcEndpointAddress,
        @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
    
      override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
        nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
      }
    
    }
    
    1. Build a new RequestMessage;
    2. Call NettyRpcEnv.ask to send the message with a timeout.
    • Enter org.apache.spark.rpc.netty.RequestMessage.scala
    private[netty] class RequestMessage(
        val senderAddress: RpcAddress,
        val receiver: NettyRpcEndpointRef,
        val content: Any) {
    
        // omitted
    
    }
    

    Parameters:
    (1) val senderAddress: RpcAddress: the sender's address; the message is sent by the NettyRpcEnv, so this is the NettyRpcEnv's address (an RpcAddress);
    (2) val receiver: NettyRpcEndpointRef: the receiver; the message is addressed to a NettyRpcEndpointRef, through which the corresponding RpcEndpoint is located to handle it;
    (3) val content: Any: the message payload.

    Concretely, in new RequestMessage(nettyEnv.address, this, message):
    (1) senderAddress is nettyEnv.address, the NettyRpcEnv's address; starting the Master creates a NettyRpcEnv on the Master node, so this is the Master's address;
    (2) receiver is this, i.e. the NettyRpcEndpointRef on which NettyRpcEnv.ask is being invoked, namely masterEndpoint;
    (3) content is message, the payload.

    • Enter org.apache.spark.rpc.netty.NettyRpcEnv.scala
      private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
        val promise = Promise[Any]()
        val remoteAddr = message.receiver.address
    
        def onFailure(e: Throwable): Unit = {
          if (!promise.tryFailure(e)) {
            e match {
              case e : RpcEnvStoppedException => logDebug (s"Ignored failure: $e")
              case _ => logWarning(s"Ignored failure: $e")
            }
          }
        }
    
        def onSuccess(reply: Any): Unit = reply match {
          case RpcFailure(e) => onFailure(e)
          case rpcReply =>
            if (!promise.trySuccess(rpcReply)) {
              logWarning(s"Ignored message: $reply")
            }
        }
    
        try {
          if (remoteAddr == address) {
            val p = Promise[Any]()
            p.future.onComplete {
              case Success(response) => onSuccess(response)
              case Failure(e) => onFailure(e)
            }(ThreadUtils.sameThread)
            dispatcher.postLocalMessage(message, p)
          } else {
            val rpcMessage = RpcOutboxMessage(message.serialize(this),
              onFailure,
              (client, response) => onSuccess(deserialize[Any](client, response)))
            postToOutbox(message.receiver, rpcMessage)
            promise.future.failed.foreach {
              case _: TimeoutException => rpcMessage.onTimeout()
              case _ =>
            }(ThreadUtils.sameThread)
          }
    
          val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
            override def run(): Unit = {
              onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
                s"in ${timeout.duration}"))
            }
          }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
          promise.future.onComplete { v =>
            timeoutCancelable.cancel(true)
          }(ThreadUtils.sameThread)
        } catch {
          case NonFatal(e) =>
            onFailure(e)
        }
        promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
      }
    
    1. Take receiver.address from the RequestMessage as remoteAddr, the receiver's address;
    2. Compare the receiver's address with the address of the NettyRpcEnv that is sending the message;
    3. If they match, the RpcEndpoint that should handle the message is registered in this very NettyRpcEnv, so a new Promise is created, a completion callback is attached to its future, and the message is posted locally via the NettyRpcEnv's internal Dispatcher.postLocalMessage;
    4. If they differ, the target RpcEndpoint is registered in another NettyRpcEnv, so an RpcOutboxMessage is built and postToOutbox delivers it to an Outbox;
    5. The NettyRpcEnv maintains an internal timeoutScheduler, which schedules a task that fails the promise with a TimeoutException if no reply arrives in time;
    6. If the reply arrives within the timeout, that scheduled timeout task is cancelled;
    7. The resulting future is returned, mapped to the expected reply type.

    Here the receiver's address equals the address of the sending NettyRpcEnv, so the message is posted locally.

    3.3.1 Posting the Message Locally

    • Enter org.apache.spark.rpc.netty.Dispatcher.scala
      /** Posts a message sent by a local endpoint. */
      def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
        val rpcCallContext =
          new LocalNettyRpcCallContext(message.senderAddress, p)
        val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
        postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
      }
    
    1. Create a local RPC call context: LocalNettyRpcCallContext;
    2. Build an RpcMessage;
    3. Post the message.
    • Enter org.apache.spark.rpc.netty.Dispatcher.scala
      /**
       * Posts a message to a specific endpoint.
       *
       * @param endpointName name of the endpoint.
       * @param message the message to post
       * @param callbackIfStopped callback function if the endpoint is stopped.
       */
      private def postMessage(
          endpointName: String,
          message: InboxMessage,
          callbackIfStopped: (Exception) => Unit): Unit = {
        val error = synchronized {
          val data = endpoints.get(endpointName)
          if (stopped) {
            Some(new RpcEnvStoppedException())
          } else if (data == null) {
            Some(new SparkException(s"Could not find $endpointName."))
          } else {
            data.inbox.post(message)
            receivers.offer(data)
            None
          }
        }
        // We don't need to call `onStop` in the `synchronized` block
        error.foreach(callbackIfStopped)
      }
    
    1. Look up the EndpointData in the Dispatcher's endpoints map by endpointName (it was registered there earlier when RpcEnv.setupEndpoint was called);
    2. Post the message into that EndpointData's Inbox;
    3. Offer the EndpointData to receivers, where it waits for a Dispatcher MessageLoop to process it.

    Once again a message has been placed on a queue; where is it taken out and processed? See "Message Processing in the Dispatcher" below.


    3.4 Message Processing in the Dispatcher

    In the steps above, creating an EndpointData also creates an Inbox, creating the Inbox enqueues OnStart into its internal messages queue, and the finished EndpointData is placed on the Dispatcher's internal receivers queue. So where are these two queues actually drained?

    The process is as follows:

    1. When the Dispatcher is constructed, it declares a thread pool:
    • Enter org.apache.spark.rpc.netty.Dispatcher.scala
      /** Thread pool used for dispatching messages. */
      private val threadpool: ThreadPoolExecutor = {
        val availableCores =
          if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
        val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
          math.max(2, availableCores))
        val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
        for (i <- 0 until numThreads) {
          pool.execute(new MessageLoop)
        }
        pool
      }
    
    2. MessageLoop implements Runnable and loops continuously, taking items off the Dispatcher's receivers queue, i.e. the EndpointData offered above, and calling each item's Inbox.process to handle the messages held in that Inbox:
    • Enter org.apache.spark.rpc.netty.Dispatcher.scala
      /** Message loop used for dispatching messages. */
      private class MessageLoop extends Runnable {
        override def run(): Unit = {
          try {
            while (true) {
              try {
                val data = receivers.take()
                if (data == PoisonPill) {
                  // Put PoisonPill back so that other MessageLoops can see it.
                  receivers.offer(PoisonPill)
                  return
                }
                data.inbox.process(Dispatcher.this)
              } catch {
                case NonFatal(e) => logError(e.getMessage, e)
              }
            }
          } catch {
            case _: InterruptedException => // exit
            case t: Throwable =>
              try {
                // Re-submit a MessageLoop so that Dispatcher will still work if
                // UncaughtExceptionHandler decides to not kill JVM.
                threadpool.execute(new MessageLoop)
              } finally {
                throw t
              }
          }
        }
      }
    
    3. Inbox.process takes messages off the Inbox's internal messages queue and handles them (for example, the OnStart added when the Inbox was created):
    • Enter org.apache.spark.rpc.netty.Inbox.scala
      /**
       * Process stored messages.
       */
      def process(dispatcher: Dispatcher): Unit = {
        var message: InboxMessage = null
        inbox.synchronized {
          if (!enableConcurrent && numActiveThreads != 0) {
            return
          }
          message = messages.poll()
          if (message != null) {
            numActiveThreads += 1
          } else {
            return
          }
        }
        while (true) {
          safelyCall(endpoint) {
            message match {
              case RpcMessage(_sender, content, context) =>
                try {
                  endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                    throw new SparkException(s"Unsupported message $message from ${_sender}")
                  })
                } catch {
                  case e: Throwable =>
                    context.sendFailure(e)
                    // Throw the exception -- this exception will be caught by the safelyCall function.
                    // The endpoint's onError function will be called.
                    throw e
                }
    
              case OneWayMessage(_sender, content) =>
                endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
                  throw new SparkException(s"Unsupported message $message from ${_sender}")
                })
    
              case OnStart =>
                endpoint.onStart()
                if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
                  inbox.synchronized {
                    if (!stopped) {
                      enableConcurrent = true
                    }
                  }
                }
    
              case OnStop =>
                val activeThreads = inbox.synchronized { inbox.numActiveThreads }
                assert(activeThreads == 1,
                  s"There should be only a single active thread but found $activeThreads threads.")
                dispatcher.removeRpcEndpointRef(endpoint)
                endpoint.onStop()
                assert(isEmpty, "OnStop should be the last message")
    
              case RemoteProcessConnected(remoteAddress) =>
                endpoint.onConnected(remoteAddress)
    
              case RemoteProcessDisconnected(remoteAddress) =>
                endpoint.onDisconnected(remoteAddress)
    
              case RemoteProcessConnectionError(cause, remoteAddress) =>
                endpoint.onNetworkError(cause, remoteAddress)
            }
          }
    
          inbox.synchronized {
            // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
            // every time.
            if (!enableConcurrent && numActiveThreads != 1) {
              // If we are not the only one worker, exit
              numActiveThreads -= 1
              return
            }
            message = messages.poll()
            if (message == null) {
              numActiveThreads -= 1
              return
            }
          }
        }
      }
    

    In the OnStart case, endpoint.onStart() is invoked.

    This means that as soon as an RpcEndpoint is registered with an RpcEnv, an OnStart message is enqueued in its Inbox inside the Dispatcher; a background MessageLoop thread then picks it up and calls the newly registered RpcEndpoint's onStart() method.

    So, in this article's case, Master.onStart() is invoked:

    • Enter org.apache.spark.deploy.master.Master.scala
      override def onStart(): Unit = {
        logInfo("Starting Spark master at " + masterUrl)
        logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
        webUi = new MasterWebUI(this, webUiPort)
        webUi.bind()
        masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
        if (reverseProxy) {
          masterWebUiUrl = conf.get("spark.ui.reverseProxyUrl", masterWebUiUrl)
          webUi.addProxy()
          logInfo(s"Spark Master is acting as a reverse proxy. Master, Workers and " +
           s"Applications UIs are available at $masterWebUiUrl")
        }
        checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
          override def run(): Unit = Utils.tryLogNonFatalError {
            self.send(CheckForWorkerTimeOut)
          }
        }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
    
        if (restServerEnabled) {
          val port = conf.getInt("spark.master.rest.port", 6066)
          restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl))
        }
        restServerBoundPort = restServer.map(_.start())
    
        masterMetricsSystem.registerSource(masterSource)
        masterMetricsSystem.start()
        applicationMetricsSystem.start()
        // Attach the master and app metrics servlet handler to the web ui after the metrics systems are
        // started.
        masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
        applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
    
        val serializer = new JavaSerializer(conf)
        val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
          case "ZOOKEEPER" =>
            logInfo("Persisting recovery state to ZooKeeper")
            val zkFactory =
              new ZooKeeperRecoveryModeFactory(conf, serializer)
            (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
          case "FILESYSTEM" =>
            val fsFactory =
              new FileSystemRecoveryModeFactory(conf, serializer)
            (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
          case "CUSTOM" =>
            val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
            val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
              .newInstance(conf, serializer)
              .asInstanceOf[StandaloneRecoveryModeFactory]
            (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
          case _ =>
            (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this))
        }
        persistenceEngine = persistenceEngine_
        leaderElectionAgent = leaderElectionAgent_
      }
    

    4 Summary

    1. When the Master starts, a NettyRpcEnv is created;
      1.1 Creating the NettyRpcEnv creates a Dispatcher inside it;
      1.2 Creating the Dispatcher creates its internal ConcurrentMap[String, EndpointData] and LinkedBlockingQueue[EndpointData];

    2. The RpcEndpoint (Master) is created;

    3. The RpcEndpoint (Master) is registered with the Dispatcher;
      3.1 Registration first creates the NettyRpcEndpointRef (masterEndpoint),
      3.2 then builds EndpointData(name, endpoint, endpointRef); building the EndpointData creates an Inbox, and building the Inbox creates a LinkedList[InboxMessage] and enqueues the OnStart message,
      3.3 then puts (name, EndpointData) into the ConcurrentMap[String, EndpointData],
      3.4 and offers the EndpointData to the LinkedBlockingQueue[EndpointData];

    4. The corresponding NettyRpcEndpointRef (masterEndpoint) is returned;

    5. The NettyRpcEndpointRef (masterEndpoint) sends a message;

    6. A RequestMessage(senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) is built, and NettyRpcEnv is asked to send it;

    7. When sending, NettyRpcEnv checks whether the receiver's address equals its own address,
      7.1 if they match, the Dispatcher posts the message locally,
      7.2 if they differ, the message is handed to an Outbox and sent to the remote NettyRpcEnv;

    8. For local delivery, the endpointName is taken from the receiver and used to look up the EndpointData in the Dispatcher's ConcurrentMap,
      8.1 the message is added to that EndpointData's Inbox (its LinkedList[InboxMessage]),
      8.2 and the EndpointData is offered to the Dispatcher's LinkedBlockingQueue[EndpointData];

    9. The Dispatcher's MessageLoop threads keep draining these two queues and processing the messages.
