pyspark与py4j线程模型简析

作者: Garfieldog | 来源:发表于2017-04-08 18:12 被阅读1100次

    事由

    上周工作中遇到一个bug,现象是一个spark streaming的job会不定期地hang住,不退出也不继续运行。这个job经是用pyspark写的,以kafka为数据源,会在每个batch结束时将统计结果写入mysql。经过排查,我们在driver进程中发现有有若干线程都出于Sl状态(睡眠状态),进而使用gdb调试发现了一处死锁。

    这是MySQLdb库旧版本中的一处bug,在此不再赘述,有兴趣的可以看这个issue。不过这倒是提起了我对另外一件事的兴趣,就是driver进程——严格的说应该是driver进程的python子进程——中的这些线程是从哪来的?当然,这些线程的存在很容易理解,我们开启了spark.streaming.concurrentJobs参数,有多个batch可以同时执行,每个线程对应一个batch。但翻遍pyspark的python代码,都没有找到有相关线程启动的地方,于是简单调研了一下pyspark到底是怎么工作的,做个记录。

    本文概括

    1. Py4J的线程模型
    2. pyspark基本原理(driver端)
    3. CPython中的deque的线程安全

    涉及软件版本

    • spark: 2.1.0
    • py4j: 0.10.4

    Py4J

    spark是由scala语言编写的,pyspark并没有像豆瓣开源的dpark用python复刻了spark,而只是提供了一层可以与原生JVM通信的python API,Py4J就是python与JVM之间的这座桥梁。这个库分为Java和Python两部分,基本原理是:

    1. Java部分,通过py4j.GatewayServer监听一个tcp socket(记做server_socket)
    2. Python部分,所有对JVM中对象的访问或者方法的调用,都是通过py4j.JavaGateway向上面这个socket完成的。
    3. 另外,Python部分在创建JavaGateway对象时,可以选择同时创建一个CallbackServer,它会在Python这册监听一个tcp socket(记做callback_socket),用来给Java回调Python代码提供一条渠道。
    4. Py4J提供了一套文本协议用来在tcp socket间传递命令。

    pyspark driver工作流程

    1. 首先,一个spark job被提交后,如果被判定这是一个python的job,spark driver会找到相应的入口,即org.apache.spark.deploy.PythonRunnermain函数,这个函数中会启动GatewayServer
        // Launch a Py4J gateway server for the process to connect to; this will let it see our
        // Java system properties and such
        val gatewayServer = new py4j.GatewayServer(null, 0)
        val thread = new Thread(new Runnable() {
          override def run(): Unit = Utils.logUncaughtExceptions {
            gatewayServer.start()
          }
        })
        thread.setName("py4j-gateway-init")
        thread.setDaemon(true)
        thread.start()
    
    1. 然后,会创建一个Python子进程来运行我们提交上来的python入口文件,并把刚才GatewayServer监听的那个端口写入到子进程的环境变量中去(这样Python才知道要通过那个端口访问JVM)
        // Launch Python process
        val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
        val env = builder.environment()
        env.put("PYTHONPATH", pythonPath)
        // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
        env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
        env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
        // pass conf spark.pyspark.python to python process, the only way to pass info to
        // python process is through environment variable.
        sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
        builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
    
    1. Python子进程这边,我们是通过pyspark提供的python API编写的这个程序,在创建SparkContext(python)时,会初始化_gateway变量(JavaGateway对象)和_jvm变量(JVMView对象)
        @classmethod
        def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
            """
            Checks whether a SparkContext is initialized or not.
            Throws error if a SparkContext is already running.
            """
            with SparkContext._lock:
                if not SparkContext._gateway:
                    SparkContext._gateway = gateway or launch_gateway(conf)
                    SparkContext._jvm = SparkContext._gateway.jvm
    
                if instance:
                    if (SparkContext._active_spark_context and
                            SparkContext._active_spark_context != instance):
                        currentMaster = SparkContext._active_spark_context.master
                        currentAppName = SparkContext._active_spark_context.appName
                        callsite = SparkContext._active_spark_context._callsite
    
                        # Raise error if there is already a running Spark context
                        raise ValueError(
                            "Cannot run multiple SparkContexts at once; "
                            "existing SparkContext(app=%s, master=%s)"
                            " created by %s at %s:%s "
                            % (currentAppName, currentMaster,
                                callsite.function, callsite.file, callsite.linenum))
                    else:
                        SparkContext._active_spark_context = instance
    

    其中launch_gateway函数可见pyspark/java_gateway.py

    1. 上面初始化的这个_jvm对象值得一说,在pyspark中很多对JVM的调用其实都是通过它来进行的,比如很多python种对应的spark对象都有一个_jsc变量,它是JVM中的SparkContext对象在Python中的影子,它是这么初始化的
        def _initialize_context(self, jconf):
            """
            Initialize SparkContext in function to allow subclass specific initialization
            """
            return self._jvm.JavaSparkContext(jconf)
    

    这里_jvm为什么能直接调用JavaSparkContext这个JVM环境中的构造函数呢?我们看JVMView中的__getattr__方法:

        def __getattr__(self, name):
            if name == UserHelpAutoCompletion.KEY:
                return UserHelpAutoCompletion()
    
            answer = self._gateway_client.send_command(
                proto.REFLECTION_COMMAND_NAME +
                proto.REFL_GET_UNKNOWN_SUB_COMMAND_NAME + name + "\n" + self._id +
                "\n" + proto.END_COMMAND_PART)
            if answer == proto.SUCCESS_PACKAGE:
                return JavaPackage(name, self._gateway_client, jvm_id=self._id)
            elif answer.startswith(proto.SUCCESS_CLASS):
                return JavaClass(
                    answer[proto.CLASS_FQN_START:], self._gateway_client)
            else:
                raise Py4JError("{0} does not exist in the JVM".format(name))
    

    self._gateway_client.send_command其实就是向server_socket发送访问对象请求的命令了,最后根据响应值生成不同类型的影子对象,针对我们这里的JavaSparkContext,就是一个JavaClass对象。这个系列的类型还包括了JavaMemberJavaPackage等等,他们也通过__getattr__来实现Java对象属性访问以及方法的调用。

    1. 我们刚才介绍Py4j时说过Python端在创建JavaGateway时,可以选择同时创建一个CallbackClient,默认情况下,一个普通的pyspark job是不会启动回调服务的,因为用不着,所有的交互都是Python --> JVM这种模式的。那什么时候需要呢?streaming job就需要(具体流程我们稍后介绍),这就(终于!)引出了我们今天主要讨论的Py4J线程模型的问题。

    Py4J线程模型

    我们已经知道了Python与JVM双方向的通信分别是通过server_socketcallack_socket来完成的,这两个socket的处理模型都是多线程模型,即,每收到一个连接就启动一个线程来处理。我们只看Python --> JVM这条通路的情况,另外一边是一样的

    Server端(Java)

        protected void processSocket(Socket socket) {
            try {
                this.lock.lock();
                if(!this.isShutdown) {
                    socket.setSoTimeout(this.readTimeout);
                    Py4JServerConnection gatewayConnection = this.createConnection(this.gateway, socket);
                    this.connections.add(gatewayConnection);
                    this.fireConnectionStarted(gatewayConnection);
                }
            } catch (Exception var6) {
                this.fireConnectionError(var6);
            } finally {
                this.lock.unlock();
            }
        }
    

    继续看createConnection:

        protected Py4JServerConnection createConnection(Gateway gateway, Socket socket) throws IOException {
            GatewayConnection connection = new GatewayConnection(gateway, socket, this.customCommands, this.listeners);
            connection.startConnection();
            return connection;
        }
    

    其中connection.startConnection其实就是创建了一个新线程,来负责处理这个连接。

    Client端(Python)

    我们来看GatewayClient中的send_command方法:

        def send_command(self, command, retry=True, binary=False):
            """Sends a command to the JVM. This method is not intended to be
               called directly by Py4J users. It is usually called by
               :class:`JavaMember` instances.
    
            :param command: the `string` command to send to the JVM. The command
             must follow the Py4J protocol.
    
            :param retry: if `True`, the GatewayClient tries to resend a message
             if it fails.
    
            :param binary: if `True`, we won't wait for a Py4J-protocol response
             from the other end; we'll just return the raw connection to the
             caller. The caller becomes the owner of the connection, and is
             responsible for closing the connection (or returning it this
             `GatewayClient` pool using `_give_back_connection`).
    
            :rtype: the `string` answer received from the JVM (The answer follows
             the Py4J protocol). The guarded `GatewayConnection` is also returned
             if `binary` is `True`.
            """
            connection = self._get_connection()
            try:
                response = connection.send_command(command)
                if binary:
                    return response, self._create_connection_guard(connection)
                else:
                    self._give_back_connection(connection)
            except Py4JNetworkError as pne:
                if connection:
                    reset = False
                    if isinstance(pne.cause, socket.timeout):
                        reset = True
                    connection.close(reset)
                if self._should_retry(retry, connection, pne):
                    logging.info("Exception while sending command.", exc_info=True)
                    response = self.send_command(command, binary=binary)
                else:
                    logging.exception(
                        "Exception while sending command.")
                    response = proto.ERROR
    
            return response
    

    这里这个self._get_connection是这么实现的

        def _get_connection(self):
            if not self.is_connected:
                raise Py4JNetworkError("Gateway is not connected.")
            try:
                connection = self.deque.pop()
            except IndexError:
                connection = self._create_connection()
            return connection
    

    这里使用了一个deque(也就是Python标准库中的collections.deque)来维护一个连接池,如果有空闲的连接,就可以直接使用,如果没有,就新建一个连接。现在问题来了,如果deque不是线程安全的,那么这段代码在多线程环境就会有问题。那么deque是不是线程安全的呢?

    deque的线程安全

    当然是了,Py4J当然不会犯这样的低级错误,我们看标准库的文档:

    Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.

    是线程安全的,不过措辞有点模糊,没有明确指出哪些方法是线程安全的,不过可以明确的是至少append的pop都是。之所以去查一下,是因为我也有点含糊,因为Python标准库还有另外一个Queue.Queue,在多线程编程中经常使用,肯定是线程安全的,于是很容易误以为deque不是线程安全的,所以我们才要一个新的Queue。这个问题,推荐阅读stackoverflow上Jonathan的这个答案——他的回答不是被采纳的最高票,不过我认为他的回答比高票更有说服力

    1. 高票答案一直强调说deque是线程安全的这个事实是个意外,是CPython中存在GIL造成的,其他Python解释器就不一定遵守。关于这一点我是不认同的,deque在CPython中的实现确实依赖的GIL才变成了线程安全的,但deque的双端append的pop是线程安全的这件事是白纸黑字写在Python文档中的,其他虚拟机的实现必须遵守,否则就不能称之为合格的Python实现。
    2. 那为什么还要有一个内部显式用了锁来做线程同步的Queue.Queue呢?Jonathan给出的回答是Queueputget可以是blocking的,而deque不行,这样一来,当你需要在多个线程中进行通信时(比如最简单的一个Producer - Consumer模式的实现),Queue往往是最佳选择。

    关于deque是否是线程安全这个问题,我将调研的结果写在了这个知乎问题的答案下Python中的deque是线程安全的吗?,就不在赘述了,这篇文章已经太长了。

    关于Py4J线程模型的问题,还可以参考官方文档中的解释

    pyspark streaming与CallbackServer

    刚才提到,如果是streaming的job,GatewayServer在初始化时会同时创建一个CallbackServer,提供JVM --> Python这条通路。

        @classmethod
        def _ensure_initialized(cls):
            SparkContext._ensure_initialized()
            gw = SparkContext._gateway
    
            java_import(gw.jvm, "org.apache.spark.streaming.*")
            java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
            java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
    
            # start callback server
            # getattr will fallback to JVM, so we cannot test by hasattr()
            if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
                gw.callback_server_parameters.eager_load = True
                gw.callback_server_parameters.daemonize = True
                gw.callback_server_parameters.daemonize_connections = True
                gw.callback_server_parameters.port = 0
                gw.start_callback_server(gw.callback_server_parameters)
                cbport = gw._callback_server.server_socket.getsockname()[1]
                gw._callback_server.port = cbport
                # gateway with real port
                gw._python_proxy_port = gw._callback_server.port
                # get the GatewayServer object in JVM by ID
                jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
                # update the port of CallbackClient with real port
                jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
    
            # register serializer for TransformFunction
            # it happens before creating SparkContext when loading from checkpointing
            cls._transformerSerializer = TransformFunctionSerializer(
                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
    

    为什么需要这样呢?一个streaming job通常需要调用foreachRDD,并提供一个函数,这个函数会在每个batch被回调:

        def foreachRDD(self, func):
            """
            Apply a function to each RDD in this DStream.
            """
            if func.__code__.co_argcount == 1:
                old_func = func
                func = lambda t, rdd: old_func(rdd)
            jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
            api = self._ssc._jvm.PythonDStream
            api.callForeachRDD(self._jdstream, jfunc)
    

    这里,Python函数func被封装成了一个TransformFunction对象,在scala端spark也定义了同样接口一个trait:

    /**
     * Interface for Python callback function which is used to transform RDDs
     */
    private[python] trait PythonTransformFunction {
      def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
    
      /**
       * Get the failure, if any, in the last call to `call`.
       *
       * @return the failure message if there was a failure, or `null` if there was no failure.
       */
      def getLastFailure: String
    }
    

    这样是Py4J提供的机制,这样就可以让JVM通过这个影子接口回调Python中的对象了,下面就是scala中的callForeachRDD函数,它把PythonTransformFunction又封装了一层成为scala中的TransformFunction, 但不管如何封装,最后都会调用PythonTransformFunction接口中的call方法完成对Python的回调。

      /**
       * helper function for DStream.foreachRDD(),
       * cannot be `foreachRDD`, it will confusing py4j
       */
      def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
        val func = new TransformFunction((pfunc))
        jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
      }
    

    所以,终于要回答这个问题了,我们一开始看到的driver中的多个线程是怎么来的

    1. python调用foreachRDD提供一个TranformFunction给scala端
    2. scala端调用自己的foreachRDD进行正常的spark streaming作业
    3. 由于我们开启了spark.streaming.concurrentJobs,多个batch可以同时运行,这在scala端是通过线程池来进行的,每个batch都需要回调Python中的TranformFunction,而按照我们之前介绍的Py4J线程模型,多个并发的回调会发现没有可用的socket连接而生成新的,而在CallbackServer(Python)这端,每个新连接都会创建一个新线程来处理。这样就出现了driver的Python进程中出现多个线程的现象。

    参考阅读

    1. MySQLdb1中的死锁issue
    2. queue-queue-vs-collections-deque
    3. Python中的deque是线程安全的吗?
    4. py4j线程模型官方文档

    相关文章

      网友评论

      本文标题:pyspark与py4j线程模型简析

      本文链接:https://www.haomeiwen.com/subject/ukilattx.html