Spark source code reading: the RPC module


Author: WJL3333 | Published 2018-07-14 01:42

    RPC is arguably the most fundamental building block of any distributed system. This post walks through Spark's internal RPC framework.

    RpcEndpoint

    The RpcEndpoint trait represents an RPC endpoint. Any class that extends this trait
    can receive and respond to RPC messages. Its main methods are:

    • Message handling

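      • def receive: PartialFunction[Any, Unit]: a partial function that handles one-way messages from other RpcEndpoints (those sent with RpcEndpointRef.send). Subclasses override it to define their own handling logic.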

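      • def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit]: similar to receive, but it handles requests sent with ask, and after processing a message it can send a reply back through the context.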

    • Callbacks

      • def onConnected(remoteAddress: RpcAddress): Unit: invoked when a remote host connects to this RpcEndpoint
      • onStart, onStop, onDisconnected and other lifecycle callbacks
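To make the two message-handling hooks concrete, here is a minimal, self-contained Scala sketch. `SimpleEndpoint`, `CallContext`, `Heartbeat` and `GetCount` are hypothetical stand-ins (Spark's real traits are `private[spark]`); the point is only that both hooks are partial functions, and that the caller dispatches into them with `applyOrElse`, failing on messages no case matches.

```scala
import scala.collection.mutable

// Hypothetical stand-ins for RpcCallContext / RpcEndpoint, not Spark's real API
trait CallContext {
  def reply(response: Any): Unit
}

trait SimpleEndpoint {
  // One-way messages (the kind RpcEndpointRef.send delivers)
  def receive: PartialFunction[Any, Unit] = PartialFunction.empty
  // Request/response messages (the kind ask delivers); answer through the context
  def receiveAndReply(context: CallContext): PartialFunction[Any, Unit] = PartialFunction.empty
  def onStart(): Unit = ()
  def onStop(): Unit = ()
}

// Hypothetical message types
case class Heartbeat(executorId: String)
case class GetCount(executorId: String)

class HeartbeatEndpoint extends SimpleEndpoint {
  private val counts = mutable.Map.empty[String, Int].withDefaultValue(0)

  override def receive: PartialFunction[Any, Unit] = {
    case Heartbeat(id) => counts(id) += 1
  }

  override def receiveAndReply(context: CallContext): PartialFunction[Any, Unit] = {
    case GetCount(id) => context.reply(counts(id))
  }
}

// Fallback for unmatched messages, similar in spirit to Inbox throwing an exception
def unsupported(msg: Any): Unit =
  throw new IllegalArgumentException(s"Unsupported message $msg")

val endpoint = new HeartbeatEndpoint
endpoint.receive.applyOrElse[Any, Unit](Heartbeat("exec-1"), unsupported)
endpoint.receive.applyOrElse[Any, Unit](Heartbeat("exec-1"), unsupported)

var replied: Any = null
val ctx = new CallContext { def reply(response: Any): Unit = replied = response }
endpoint.receiveAndReply(ctx).applyOrElse[Any, Unit](GetCount("exec-1"), unsupported)
```

After two heartbeats, the `GetCount` request is answered with `2` through the context.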

    RpcEndpointRef

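    An RpcEndpointRef represents the current side's connection to a remote RpcEndpoint. To send an RPC message to another host, first obtain an RpcEndpointRef from the remote address (an RpcAddress, a case class identifying the remote host and port) and the endpoint name, then send RPC messages to the remote node through this object. Its main methods are: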

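    • Asynchronous request: def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
      Sends an arbitrary message to the remote endpoint and returns a Future; once the remote side replies, the result can be read from this Future.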

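    • Synchronous request: def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T blocks until the reply arrives, throwing an exception if it does not arrive within the timeout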

    • One-way send: def send(message: Any): Unit sends the message without waiting for any reply (fire-and-forget)
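The three call shapes can be imitated in-process with Scala futures. `LocalRef` below is a hypothetical sketch, not Spark's `RpcEndpointRef`: there is no networking, the handlers simply run on an `ExecutionContext`, and only `askSync` takes a timeout. It does mirror one detail of the real signatures: `ask[T: ClassTag]` narrows the untyped reply with `Future.mapTo[T]`.

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import scala.reflect.ClassTag

// Hypothetical in-process "reference" with the same three call shapes as RpcEndpointRef
class LocalRef(oneWay: PartialFunction[Any, Unit], requestReply: PartialFunction[Any, Any])
              (implicit ec: ExecutionContext) {

  // Fire-and-forget: the outcome (including a MatchError) is never reported back
  def send(message: Any): Unit = { Future(oneWay(message)); () }

  // Asynchronous request: the untyped reply is narrowed to T via its ClassTag
  def ask[T: ClassTag](message: Any): Future[T] =
    Future(requestReply(message)).mapTo[T]

  // Synchronous request: block until the reply arrives or the timeout expires
  def askSync[T: ClassTag](message: Any, timeout: FiniteDuration): T =
    Await.result(ask[T](message), timeout)
}

implicit val ec: ExecutionContext = ExecutionContext.global

val pinged = new CountDownLatch(1)
val ref = new LocalRef(
  { case "ping" => pinged.countDown() },
  { case ("add", a: Int, b: Int) => a + b }
)

ref.send("ping")
val sum: Int = ref.askSync[Int](("add", 2, 3), 1.second)
val later: Future[Int] = ref.ask[Int](("add", 4, 5))
```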

    RpcEnv

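    This trait is central to the whole framework: it keeps track of the endpoints and is responsible for dispatching RPC messages. Every RpcEndpoint is registered with an RpcEnv. To connect to and exchange messages with another RpcEndpoint, an endpoint registers itself with the remote RpcEndpoint; after receiving the registration, the remote side stores the connecting party's information in its RpcEnv, and from then on the two RpcEndpoints are connected (they can send messages in both directions).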

    • Endpoint registration and lookup methods

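      • def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
        Registers an endpoint with the local RpcEnv under the given name. A process may host several endpoints, e.g. one that receives heartbeats, one that tracks job status, one that listens for messages returned by executors, and so on.
        An RpcEndpoint sends messages to an RpcEndpointRef through the RpcEnv.
        The RpcEnv dispatches the messages it receives to the RpcEndpoints registered with it.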

      • def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] asynchronously obtains an RpcEndpointRef for the endpoint identified by the URI

      • def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef is the blocking counterpart: it obtains a reference from a remote address and an endpoint name

    • Lifecycle methods

      • stop
      • shutdown
      • awaitTermination
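`asyncSetupEndpointRefByURI` takes the endpoint's URI as a string. Spark renders endpoint addresses in the form `spark://<endpointName>@<host>:<port>`, with the endpoint name in the user-info part of the URI. The sketch below parses that form with `java.net.URI`; `EndpointAddress` and `parseEndpointUri` are hypothetical helpers, and the validation rules are this sketch's own rather than a copy of Spark's checks. `CoarseGrainedScheduler` is only an example name.

```scala
import java.net.URI

// Hypothetical simplified endpoint address: endpoint name plus host/port of its RpcEnv
case class EndpointAddress(name: String, host: String, port: Int) {
  override def toString: String = s"spark://$name@$host:$port"
}

// Parses "spark://name@host:port"; the endpoint name travels in the URI's user-info part
def parseEndpointUri(sparkUrl: String): EndpointAddress = {
  val uri = new URI(sparkUrl)
  val name = uri.getUserInfo
  val valid = uri.getScheme == "spark" && name != null && uri.getHost != null &&
    uri.getPort >= 0 && (uri.getPath == null || uri.getPath.isEmpty) &&
    uri.getQuery == null && uri.getFragment == null
  if (!valid) throw new IllegalArgumentException(s"Invalid endpoint URI: $sparkUrl")
  EndpointAddress(name, uri.getHost, uri.getPort)
}

val addr = parseEndpointUri("spark://CoarseGrainedScheduler@10.0.0.5:7077")
```

Parsing and `toString` round-trip: the address prints back as the same URI.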

    RpcCallContext

    RpcCallContext is explained in the walkthrough below; here is its definition first:

    private[spark] trait RpcCallContext {
    
      /**
       * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]]
       * will be called.
       */
      def reply(response: Any): Unit
    
      /**
       * Report a failure to the sender.
       */
      def sendFailure(e: Throwable): Unit
    
      /**
       * The sender of this message.
       */
      def senderAddress: RpcAddress
    }
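For a local call, the context only has to complete a `Promise` that the caller holds; Spark's `LocalNettyRpcCallContext`, created in `postLocalMessage` further down, is built from the sender address and such a Promise. A minimal sketch of the idea: `CallContext` and `LocalCallContext` are hypothetical names, the sender address is a plain `String` instead of `RpcAddress`, and failures go straight to `Promise.failure` as a simplification.

```scala
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

// Same three members as RpcCallContext, with a String standing in for RpcAddress
trait CallContext {
  def reply(response: Any): Unit
  def sendFailure(e: Throwable): Unit
  def senderAddress: String
}

// Local variant: replying completes the Promise the asking side is waiting on
class LocalCallContext(val senderAddress: String, p: Promise[Any]) extends CallContext {
  override def reply(response: Any): Unit = p.success(response)
  override def sendFailure(e: Throwable): Unit = p.failure(e)
}

val p = Promise[Any]()
val ctx: CallContext = new LocalCallContext("localhost:7077", p)
ctx.reply("pong")
val answer = Await.result(p.future, 1.second)

val failed = Promise[Any]()
new LocalCallContext("localhost:7077", failed).sendFailure(new RuntimeException("boom"))
```

Because a Promise can be completed only once, replying twice on the same context would throw, which is a reasonable guard against double replies.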
    

    Spark implements these RPC interfaces on top of Netty. Let's look at the Netty-based implementation.

    NettyRpcEnvFactory

    private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
    
      def create(config: RpcEnvConfig): RpcEnv = {
        val sparkConf = config.conf
        // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
        // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
        val javaSerializerInstance =
          new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
        val nettyEnv =
          new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
            config.securityManager)
        if (!config.clientMode) {
          val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
            nettyEnv.startServer(config.bindAddress, actualPort)
            (nettyEnv, nettyEnv.address.port)
          }
          try {
            Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
          } catch {
            case NonFatal(e) =>
              nettyEnv.shutdown()
              throw e
          }
        }
        nettyEnv
      }
    }
    

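    A factory that creates a NettyRpcEnv object. Unless the env runs in client mode (config.clientMode),
    it also starts a Netty server (the nettyEnv.startServer method) through Utils.startServiceOnPort, which retries on subsequent ports if the requested port is already in use.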
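The retry idea behind `Utils.startServiceOnPort` can be sketched in a few lines: try the configured port and, on a bind failure, move on to the next one until a retry limit is reached. `startOnPort` is a hypothetical helper, simplified on purpose: the retry count is a plain parameter and the "server" in the usage example is simulated.

```scala
import java.net.BindException

// Hypothetical helper: tries startPort, startPort + 1, ... until startService succeeds.
// startService returns the started service together with the port it actually bound.
def startOnPort[T](startPort: Int, maxRetries: Int)(startService: Int => (T, Int)): (T, Int) = {
  require(maxRetries >= 0, "maxRetries must be non-negative")
  var offset = 0
  var lastError: BindException = null
  var result: Option[(T, Int)] = None
  while (result.isEmpty && offset <= maxRetries) {
    // Port 0 asks the OS for a free port, so it is never incremented
    val tryPort = if (startPort == 0) 0 else startPort + offset
    try {
      result = Some(startService(tryPort))
    } catch {
      case e: BindException => // port already in use: move on to the next one
        lastError = e
        offset += 1
    }
  }
  result.getOrElse(
    throw new BindException(s"Failed to bind after ${maxRetries + 1} attempts: ${lastError.getMessage}"))
}

// Simulated service: ports 7077 and 7078 are "already in use"
val busy = Set(7077, 7078)
def fakeServer(port: Int): (String, Int) = {
  if (busy.contains(port)) throw new BindException(s"Address already in use: $port")
  (s"server@$port", port)
}

val (server, boundPort) = startOnPort(7077, maxRetries = 16)(p => fakeServer(p))
```

The first two attempts fail, and the service ends up on port 7079.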

    NettyRpcEnv

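    The most important member of this object is a Dispatcher. It also holds the TransportContext used to create the server and a map of Outboxes keyed by remote address: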

    private[netty] class NettyRpcEnv(
        val conf: SparkConf,
        javaSerializerInstance: JavaSerializerInstance,
        host: String,
        securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
    
      ...
      private val dispatcher: Dispatcher = new Dispatcher(this)
      ...
      private val transportContext = new TransportContext(transportConf,
        new NettyRpcHandler(dispatcher, this, streamManager))
      ...
      @volatile private var server: TransportServer = _
      private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
      ... 
    
      def startServer(bindAddress: String, port: Int): Unit = {
        ...
        server = transportContext.createServer(bindAddress, port, bootstraps)
        dispatcher.registerRpcEndpoint(
          RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
      }
    }
    
    

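    As mentioned above, the factory calls startServer.
    Besides creating the TransportServer, this method registers an RpcEndpointVerifier with the dispatcher; the RpcEndpointVerifier is itself an RpcEndpoint.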

    private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
      extends RpcEndpoint {
    
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
      }
    }
    
    private[netty] object RpcEndpointVerifier {
      val NAME = "endpoint-verifier"
    
      /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */
      case class CheckExistence(name: String)
    }
    

    This is the first concrete RpcEndpoint we meet: when it receives a CheckExistence message, it calls dispatcher.verify(name) and replies with the result.
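The verifier is what reference lookup relies on: in Spark, `asyncSetupEndpointRefByURI` asks the remote `endpoint-verifier` with `CheckExistence(name)` and only hands back the reference if the answer is `true`. The sketch below reproduces just the verifier side. A `ConcurrentHashMap` stands in for the Dispatcher's endpoint map, `EndpointVerifier` and `CallContext` are simplified hypothetical types, and `HeartbeatReceiver` is only an example name.

```scala
import java.util.concurrent.ConcurrentHashMap

// Simplified copy of the CheckExistence message
case class CheckExistence(name: String)

// Hypothetical minimal reply context
trait CallContext { def reply(response: Any): Unit }

// Verifier endpoint: answers whether an endpoint with the given name is registered
class EndpointVerifier(registered: ConcurrentHashMap[String, AnyRef]) {
  def receiveAndReply(context: CallContext): PartialFunction[Any, Unit] = {
    case CheckExistence(name) => context.reply(registered.containsKey(name))
  }
}

// The registry plays the role of Dispatcher.endpoints (values are placeholders here)
val registry = new ConcurrentHashMap[String, AnyRef]()
registry.put("endpoint-verifier", "verifier")
registry.put("HeartbeatReceiver", "heartbeat")

var answers = List.empty[Any]
val ctx = new CallContext { def reply(response: Any): Unit = answers = answers :+ response }
val verifier = new EndpointVerifier(registry)
verifier.receiveAndReply(ctx)(CheckExistence("HeartbeatReceiver"))
verifier.receiveAndReply(ctx)(CheckExistence("NoSuchEndpoint"))
```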

    Let's look at the dispatcher object first.

    Dispatcher

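    The Dispatcher's job is to route received RPC messages to the right endpoint. Internally it keeps a ConcurrentHashMap holding every registered RpcEndpoint, keyed by name, plus a receivers queue of endpoints that have messages waiting: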

    private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
    
      private class EndpointData(
          val name: String,
          val endpoint: RpcEndpoint,
          val ref: NettyRpcEndpointRef) {
        val inbox = new Inbox(ref, endpoint)
      }
    
      private val endpoints: ConcurrentMap[String, EndpointData] =
        new ConcurrentHashMap[String, EndpointData]
    
      private val receivers = new LinkedBlockingQueue[EndpointData]
      ....
    
    }
    

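    The registerRpcEndpoint method mentioned above puts the RpcEndpointVerifier into both of these containers.
    Other endpoints use the RpcEndpointVerifier to check whether a given endpoint is registered in this RpcEnv.
    A remote endpoint sends this RpcEnv a message containing the name of the endpoint it wants to reach; the RpcEndpointVerifier then checks whether the container of registered endpoints holds that name and replies with the result.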

    NettyRpcEndpointRef

    As mentioned earlier, an RpcEndpointRef stands for a remote endpoint and is used to send RPC messages to it:

    
    private[netty] class NettyRpcEndpointRef(
        @transient private val conf: SparkConf,
        private val endpointAddress: RpcEndpointAddress,
        @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
        ...
    
        override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
            nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
        }
    }
    

    ask simply delegates to NettyRpcEnv.ask:

    private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
        val promise = Promise[Any]()
        val remoteAddr = message.receiver.address
    
        def onFailure(e: Throwable): Unit = { ... }
        def onSuccess(reply: Any): Unit = reply match { ... }
    
        try {
          if (remoteAddr == address) {
            val p = Promise[Any]()
            p.future.onComplete {
              case Success(response) => onSuccess(response)
              case Failure(e) => onFailure(e)
            }(ThreadUtils.sameThread)
            dispatcher.postLocalMessage(message, p)
          } else {
            val rpcMessage = RpcOutboxMessage(message.serialize(this),
              onFailure,
              (client, response) => onSuccess(deserialize[Any](client, response)))
            postToOutbox(message.receiver, rpcMessage)
            promise.future.onFailure {
              case _: TimeoutException => rpcMessage.onTimeout()
              case _ =>
            }(ThreadUtils.sameThread)
          }
    
          val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
            override def run(): Unit = {
              onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
                s"in ${timeout.duration}"))
            }
          }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
          promise.future.onComplete { v =>
            timeoutCancelable.cancel(true)
          }(ThreadUtils.sameThread)
        } catch { ... }
        promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
      }
    

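    This method has three parts:
    Part 1: if the message is addressed to an RpcEndpoint registered in the local RpcEnv (remoteAddr == address), it is delivered locally via dispatcher.postLocalMessage.
    Part 2: if it is addressed to a remote endpoint, it is wrapped in an RpcOutboxMessage and put into an Outbox to wait to be sent.
    Part 3: timeout handling. A scheduled task fails the call with a TimeoutException if no reply arrives in time, and a callback registered on the Promise cancels that task if the call completes before the timeout.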
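The promise-plus-timeout pattern can be tried on its own. This is a hypothetical, simplified `ask`: instead of a dispatcher or an Outbox, a caller-supplied `deliver` function receives the `Promise`, and the timeout message is abbreviated. As in the code above, the timeout uses `tryFailure` so that a late reply and the timeout task can race safely, and the scheduled task is cancelled once the promise completes.

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit, TimeoutException}
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import scala.reflect.ClassTag

// Runs callbacks on the completing thread, like ThreadUtils.sameThread
val sameThread: ExecutionContext = new ExecutionContext {
  def execute(runnable: Runnable): Unit = runnable.run()
  def reportFailure(cause: Throwable): Unit = cause.printStackTrace()
}

// Daemon scheduler playing the role of timeoutScheduler
val timeoutScheduler: ScheduledExecutorService =
  Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
    def newThread(r: Runnable): Thread = { val t = new Thread(r, "ask-timeout"); t.setDaemon(true); t }
  })

// Hypothetical ask: `deliver` hands the promise to whoever will answer (local inbox or outbox)
def ask[T: ClassTag](deliver: Promise[Any] => Unit, timeout: FiniteDuration): Future[T] = {
  val promise = Promise[Any]()
  deliver(promise)
  // Fail the call if no reply arrives in time (tryFailure: the reply may win the race)
  val timeoutTask = timeoutScheduler.schedule(new Runnable {
    override def run(): Unit =
      promise.tryFailure(new TimeoutException(s"Cannot receive any reply in $timeout"))
  }, timeout.toNanos, TimeUnit.NANOSECONDS)
  // If the call completes first, cancel the pending timeout task
  promise.future.onComplete(_ => timeoutTask.cancel(true))(sameThread)
  promise.future.mapTo[T]
}

// A responder that answers after 50 ms, and one that never answers
val fast = ask[String](p => {
  timeoutScheduler.schedule(new Runnable { def run(): Unit = p.trySuccess("pong") }, 50, TimeUnit.MILLISECONDS)
  ()
}, 1.second)
val slow = ask[String](_ => (), 100.millis)
```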

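    Let's start with dispatcher.postLocalMessage, which wraps the call into an RpcMessage together with a LocalNettyRpcCallContext that holds the Promise: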

    def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
        val rpcCallContext =
          new LocalNettyRpcCallContext(message.senderAddress, p)
        val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
        postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
      }
    

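    It delegates to dispatcher.postMessage, which does three things:

    1. looks up the EndpointData for the target endpoint
    2. posts the message into that EndpointData's inbox
    3. puts the EndpointData into the receivers queue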

           
    private def postMessage(
          endpointName: String,
          message: InboxMessage,
          callbackIfStopped: (Exception) => Unit): Unit ={
           ...
          val data = endpoints.get(endpointName)
          data.inbox.post(message)
          receivers.offer(data)
           ...
    }
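The three steps of postMessage, plus the registration that precedes them, can be modeled without any threads. A minimal sketch: `MiniDispatcher` and `EndpointData` with a plain `LinkedList` as its inbox are hypothetical simplifications, and the error messages are this sketch's own. As in Spark, a missing endpoint is reported through `callbackIfStopped` rather than thrown.

```scala
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

// Hypothetical, thread-free miniature of the Dispatcher's bookkeeping
final class EndpointData(val name: String) {
  // Stand-in for Inbox.messages; guarded by its own lock
  val inbox = new java.util.LinkedList[Any]()
}

class MiniDispatcher {
  private val endpoints = new ConcurrentHashMap[String, EndpointData]()
  val receivers = new LinkedBlockingQueue[EndpointData]()

  def register(name: String): Unit = {
    val data = new EndpointData(name)
    if (endpoints.putIfAbsent(name, data) != null)
      throw new IllegalArgumentException(s"There is already an endpoint called $name")
    data.inbox.synchronized { data.inbox.add("OnStart") } // a new Inbox starts with OnStart queued
    receivers.offer(data)
  }

  def postMessage(endpointName: String, message: Any, callbackIfStopped: Exception => Unit): Unit = {
    val data = endpoints.get(endpointName)                // 1. look up the EndpointData
    if (data == null) {
      callbackIfStopped(new IllegalStateException(s"Could not find endpoint $endpointName"))
    } else {
      data.inbox.synchronized { data.inbox.add(message) } // 2. post into its inbox
      receivers.offer(data)                               // 3. signal that it has work
    }
  }
}

val dispatcher = new MiniDispatcher
dispatcher.register("endpoint-verifier")
dispatcher.postMessage("endpoint-verifier", "CheckExistence(HeartbeatReceiver)", _ => ())
var error: Option[Exception] = None
dispatcher.postMessage("missing", "hello", e => error = Some(e))
```

The same `EndpointData` can sit in `receivers` more than once. That is harmless, because `Inbox.process` drains every queued message and simply returns when the inbox is empty.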
    

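    The inbox holds the messages destined for its endpoint, so conceptually the endpoint has already received the message at this point. But post only appends the message to a queue; how does it actually reach the endpoint?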

    private[netty] class Inbox(
        val endpointRef: NettyRpcEndpointRef,
        val endpoint: RpcEndpoint)
      extends Logging {
    
      inbox =>  // Give this an alias so we can use it more clearly in closures.
    
      @GuardedBy("this")
      protected val messages = new java.util.LinkedList[InboxMessage]()
      ...
     
      def post(message: InboxMessage): Unit = inbox.synchronized {
        if (stopped) {
          // We already put "OnStop" into "messages", so we should drop further messages
          onDrop(message)
        } else {
          messages.add(message)
          false
        }
      ...
      }
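The comment in `post` refers to the stop protocol: once the inbox is stopped, `OnStop` is already queued and any later message goes to `onDrop` instead of the queue. A minimal sketch of that rule; `MiniInbox` and its `dropped` buffer are hypothetical, and messages are plain strings.

```scala
import scala.collection.mutable.ArrayBuffer

class MiniInbox {
  private val messages = new java.util.LinkedList[Any]() // guarded by `this`
  private var stopped = false                            // guarded by `this`
  val dropped = ArrayBuffer.empty[Any]                   // stands in for onDrop

  def post(message: Any): Unit = synchronized {
    if (stopped) dropped += message // OnStop is already queued: drop late messages
    else messages.add(message)
  }

  def stop(): Unit = synchronized {
    if (!stopped) {                 // idempotent: OnStop is queued only once
      stopped = true
      messages.add("OnStop")
    }
  }

  // Drain everything queued so far, in FIFO order
  def drain(): List[Any] = synchronized {
    val out = List.newBuilder[Any]
    while (!messages.isEmpty) out += messages.poll()
    out.result()
  }
}

val inbox = new MiniInbox
inbox.post("a")
inbox.stop()
inbox.post("b")
inbox.stop()
val processed = inbox.drain()
```

The endpoint sees `"a"` and then `OnStop`; `"b"` arrives too late and is dropped.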
    

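    The Dispatcher owns a thread pool whose size comes from spark.rpc.netty.dispatcher.numThreads (defaulting to max(2, number of available processors)). Each thread runs a MessageLoop that keeps taking EndpointData objects from the receivers queue and processing the messages stored in their inbox: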

    private val threadpool: ThreadPoolExecutor = {
        val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
          math.max(2, Runtime.getRuntime.availableProcessors()))
        val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
        for (i <- 0 until numThreads) {
          pool.execute(new MessageLoop)
        }
        pool
      }
    
    private class MessageLoop extends Runnable {
        override def run(): Unit = {
          try {
            while (true) {
              try {
                val data = receivers.take()
                if (data == PoisonPill) {
                  // Put PoisonPill back so that other MessageLoops can see it.
                  receivers.offer(PoisonPill)
                  return
                }
                data.inbox.process(Dispatcher.this)
              } catch {
                case NonFatal(e) => logError(e.getMessage, e)
              }
            }
          } catch {
            case ie: InterruptedException => // exit
          }
        }
      }
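The poison-pill shutdown can be tried in isolation. In this hypothetical sketch, `Work` stands in for `EndpointData` and running its task stands in for `inbox.process`. A single `PoisonPill` is enough to stop every loop, because each loop puts it back before exiting.

```scala
import java.util.concurrent.{CountDownLatch, Executors, LinkedBlockingQueue, TimeUnit}
import scala.util.control.NonFatal

// Hypothetical work item; PoisonPill is compared by reference, as a sentinel
final class Work(val label: String, val task: () => Unit)
val PoisonPill = new Work("poison-pill", () => ())

val receivers = new LinkedBlockingQueue[Work]()

class MessageLoop extends Runnable {
  override def run(): Unit =
    try {
      var running = true
      while (running) {
        val work = receivers.take()
        if (work eq PoisonPill) {
          receivers.offer(PoisonPill) // put it back so the other loops see it too
          running = false
        } else {
          try { work.task() } catch { case NonFatal(e) => System.err.println(s"${work.label} failed: $e") }
        }
      }
    } catch {
      case _: InterruptedException => // interrupted: just exit
    }
}

val numThreads = 2
val pool = Executors.newFixedThreadPool(numThreads)
(0 until numThreads).foreach(_ => pool.execute(new MessageLoop))

val done = new CountDownLatch(3)
(1 to 3).foreach(i => receivers.offer(new Work(s"job-$i", () => done.countDown())))
val allProcessed = done.await(1, TimeUnit.SECONDS)

receivers.offer(PoisonPill) // one pill stops both loops
pool.shutdown()
val terminated = pool.awaitTermination(1, TimeUnit.SECONDS)
```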
    

    Now back to the inbox.process method:

    def process(dispatcher: Dispatcher): Unit = {
        var message: InboxMessage = null
        inbox.synchronized {
          ... 
          message = messages.poll()
          ...
        }
        while (true) {
          safelyCall(endpoint) {
            message match {
              case RpcMessage(_sender, content, context) =>
                try {
                  endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                    throw new SparkException(s"Unsupported message $message from ${_sender}")
                  })
                } catch { ... }
    
              case OneWayMessage(_sender, content) =>
                endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
                  throw new SparkException(s"Unsupported message $message from ${_sender}")
                })
    
              case OnStart =>
                endpoint.onStart()
                if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
                  inbox.synchronized {
                    if (!stopped) {
                      enableConcurrent = true
                    }
                  }
                }
    
              case OnStop =>
                val activeThreads = inbox.synchronized { inbox.numActiveThreads }
                ...
                dispatcher.removeRpcEndpointRef(endpoint)
                endpoint.onStop()
                ...
    
              case RemoteProcessConnected(remoteAddress) =>
                endpoint.onConnected(remoteAddress)
    
              case RemoteProcessDisconnected(remoteAddress) =>
                endpoint.onDisconnected(remoteAddress)
    
              case RemoteProcessConnectionError(cause, remoteAddress) =>
                endpoint.onNetworkError(cause, remoteAddress)
            }
          }
    
          inbox.synchronized {
            ... 
            message = messages.poll()
            if (message == null) {
              numActiveThreads -= 1
              return
            }
          }
        }
      }
    

    The method keeps polling the messages queue until it is empty.
    Which case of this pattern match does the message we posted locally correspond to? Here are the InboxMessage types:

    private[netty] sealed trait InboxMessage
    
    private[netty] case class OneWayMessage(
        senderAddress: RpcAddress,
        content: Any) extends InboxMessage
    
    private[netty] case class RpcMessage(
        senderAddress: RpcAddress,
        content: Any,
        context: NettyRpcCallContext) extends InboxMessage
    
    private[netty] case object OnStart extends InboxMessage
    
    private[netty] case object OnStop extends InboxMessage
    

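    The local message sent earlier is an RpcMessage. Each Inbox belongs to exactly one Endpoint, so process calls endpoint.receiveAndReply directly, and at this point the message has really been delivered to the endpoint. (RpcEndpointVerifier.receiveAndReply is one such RpcEndpoint; in this flow you can think of it as the local RpcEndpoint asking the local RpcEnv whether an endpoint has been registered.)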
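The routing in `process` boils down to a match over a sealed message hierarchy. Below is a single-threaded sketch of that loop. The types are trimmed-down hypothetical copies: a `String` replaces `RpcAddress` and a reply function replaces `NettyRpcCallContext`. There is no `enableConcurrent` or `numActiveThreads`, so the sketch shows only how each message kind reaches its callback.

```scala
// Simplified copies of the InboxMessage hierarchy (hypothetical simplifications)
sealed trait InboxMessage
case class OneWayMessage(senderAddress: String, content: Any) extends InboxMessage
case class RpcMessage(senderAddress: String, content: Any, reply: Any => Unit) extends InboxMessage
case object OnStart extends InboxMessage
case object OnStop extends InboxMessage

trait Endpoint {
  def receive: PartialFunction[Any, Unit]
  def receiveAndReply(reply: Any => Unit): PartialFunction[Any, Unit]
  def onStart(): Unit = ()
  def onStop(): Unit = ()
}

// Poll one message at a time and route it by type until the queue is empty
def process(messages: java.util.Queue[InboxMessage], endpoint: Endpoint): Unit = {
  var message = messages.poll()
  while (message != null) {
    message match {
      case RpcMessage(sender, content, reply) =>
        endpoint.receiveAndReply(reply).applyOrElse[Any, Unit](content,
          _ => throw new IllegalArgumentException(s"Unsupported message $content from $sender"))
      case OneWayMessage(sender, content) =>
        endpoint.receive.applyOrElse[Any, Unit](content,
          _ => throw new IllegalArgumentException(s"Unsupported message $content from $sender"))
      case OnStart => endpoint.onStart()
      case OnStop  => endpoint.onStop()
    }
    message = messages.poll()
  }
}

val log = scala.collection.mutable.ArrayBuffer.empty[String]
var answer: Any = null
val endpoint = new Endpoint {
  def receive: PartialFunction[Any, Unit] = { case s: String => log += s"got $s" }
  def receiveAndReply(reply: Any => Unit): PartialFunction[Any, Unit] = { case n: Int => reply(n * 2) }
  override def onStart(): Unit = log += "started"
}

val queue = new java.util.LinkedList[InboxMessage]()
queue.add(OnStart)
queue.add(OneWayMessage("host-a:7077", "hello"))
queue.add(RpcMessage("host-b:7077", 21, r => answer = r))
process(queue, endpoint)
```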

    Now let's look at messages for a remote RpcEndpoint: the message is wrapped in an RpcOutboxMessage and passed to postToOutbox:

    private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
        if (receiver.client != null) {
          message.sendWith(receiver.client)
        } else {
          ...
          val targetOutbox = {
            val outbox = outboxes.get(receiver.address)
            ...
          }
          if (stopped.get) { ... } else {
            targetOutbox.send(message)
          }
        }
      }
    
    private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
    outbox => // Give this an alias so we can use it more clearly in closures.
    
      @GuardedBy("this")
      private val messages = new java.util.LinkedList[OutboxMessage]
    
      @GuardedBy("this")
      private var client: TransportClient = null
    
      @GuardedBy("this")
      private var connectFuture: java.util.concurrent.Future[Unit] = null
    
      def send(message: OutboxMessage): Unit = {
        val dropped = synchronized {
          if (stopped) { ... } else {
            messages.add(message)
            false
          }
        }
        if (dropped) { ... } else {
          drainOutbox()
        }
      }
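`send` enqueues under the lock and then calls `drainOutbox`. The sketch below shows that enqueue-then-drain pattern with a lazily created connection and a single active drainer. `MiniOutbox` and `FakeClient` are hypothetical; in this simplification the connection is opened synchronously on first use, whereas Spark opens it in the background.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical connection: records what it "sends" instead of writing to a socket
final class FakeClient(val address: String) {
  val sent = ArrayBuffer.empty[String]
  def send(message: String): Unit = sent += message
}

// Simplified Outbox: enqueue under the lock, connect on first use, and let one
// caller at a time drain the queue in FIFO order
class MiniOutbox(address: String, connect: String => FakeClient) {
  private val messages = new java.util.LinkedList[String]() // guarded by `this`
  private var client: FakeClient = null                     // guarded by `this`
  private var draining = false                              // guarded by `this`

  def send(message: String): Unit = {
    synchronized { messages.add(message) }
    drainOutbox()
  }

  def currentClient: Option[FakeClient] = synchronized { Option(client) }

  private def drainOutbox(): Unit = {
    // Step 1: make sure a connection exists; bail out if another caller is draining
    val drainClient: FakeClient = synchronized {
      if (draining) null
      else {
        if (client == null) client = connect(address)
        draining = true
        client
      }
    }
    if (drainClient != null) {
      // Step 2: send queued messages until the queue is empty
      var next = synchronized { pollOrFinish() }
      while (next != null) {
        drainClient.send(next)
        next = synchronized { pollOrFinish() }
      }
    }
  }

  // Caller must hold the lock: take the next message, or end the drain when empty
  private def pollOrFinish(): String = {
    val m = messages.poll()
    if (m == null) draining = false
    m
  }
}

var connections = 0
val outbox = new MiniOutbox("host-b:7077", addr => { connections += 1; new FakeClient(addr) })
outbox.send("msg-1")
outbox.send("msg-2")
```

Both messages go out over one connection, created by the first `send`.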
     
    

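    Each Outbox contains:

    • a queue of pending messages
    • a TransportClient that connects to the remote RpcEnv and sends the messages

    drainOutbox does two things:

    1. checks whether a connection to the remote RpcEnv exists and, if not, starts a background task to establish one
    2. works through the queue, sending each message to the TransportServer of the remote RpcEnv

    On the remote side, these messages are handled by NettyRpcHandler, which hands them to the Dispatcher via postRemoteMessage:
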
    private[netty] class NettyRpcHandler(
        dispatcher: Dispatcher,
        nettyEnv: NettyRpcEnv,
        streamManager: StreamManager) extends RpcHandler with Logging {
    
      // A variable to track the remote RpcEnv addresses of all clients
      private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()
    
      override def receive(
          client: TransportClient,
          message: ByteBuffer,
          callback: RpcResponseCallback): Unit = {
        val messageToDispatch = internalReceive(client, message)
        dispatcher.postRemoteMessage(messageToDispatch, callback)
      }
    }
    
    def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
        val rpcCallContext =
          new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
        val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
        postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
      }
    

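    So we meet postMessage again, but this time it is the postMessage of the remote RpcEnv's Dispatcher. The message is eventually delivered to the RpcEndpoint registered in the remote RpcEnv, so the remote RpcEndpoint receives the message sent from the local side, completing the RPC call.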
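On the remote side the reply cannot complete a local `Promise`. Instead, the context passed to `postRemoteMessage` (`RemoteNettyRpcCallContext`) serializes the reply and hands the bytes to the `RpcResponseCallback`. Here is a self-contained sketch of that idea using plain Java serialization. `ResponseCallback` and `RemoteCallContext` are hypothetical stand-ins, and in this sketch failures go straight to `onFailure`.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer

// Hypothetical stand-in for the transport layer's response callback
trait ResponseCallback {
  def onSuccess(response: ByteBuffer): Unit
  def onFailure(e: Throwable): Unit
}

// Java serialization to and from a ByteBuffer
def serialize(value: Any): ByteBuffer = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  try out.writeObject(value.asInstanceOf[AnyRef]) finally out.close()
  ByteBuffer.wrap(bytes.toByteArray)
}

def deserialize(buffer: ByteBuffer): Any = {
  val arr = new Array[Byte](buffer.remaining())
  buffer.get(arr)
  val in = new ObjectInputStream(new ByteArrayInputStream(arr))
  try in.readObject() finally in.close()
}

// Remote variant of the call context: the reply leaves the process as bytes
class RemoteCallContext(callback: ResponseCallback) {
  def reply(response: Any): Unit = callback.onSuccess(serialize(response))
  def sendFailure(e: Throwable): Unit = callback.onFailure(e)
}

var received: Any = null
val callback = new ResponseCallback {
  def onSuccess(response: ByteBuffer): Unit = received = deserialize(response)
  def onFailure(e: Throwable): Unit = received = e
}
new RemoteCallContext(callback).reply("registered")
```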


    Article link: https://www.haomeiwen.com/subject/mjcepftx.html