
Improving Database Write Efficiency in Spark Streaming

Author: 郭寻抚 | Published 2016-11-01 09:33 | 792 reads

    1. Introduction

    This is a bit of a bait-and-switch post: what it actually covers has little to do with Spark Streaming itself.

    In the previous post, http://www.jianshu.com/p/a73c0c95d2fe, we showed how to insert data into a database from Spark Streaming. As you may have noticed, the records were inserted one at a time, which is inefficient. So how do we speed up the database writes?
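
    For contrast, the per-record approach boils down to roughly the sketch below: one statement execution, and with auto-commit on, one commit, for every single record. This is a simplified, hypothetical fragment; the kafka_message table and column layout match the examples later in this post.

    import java.sql.Connection

    // Row-by-row writing: every record pays for its own database round trip
    // (and, with auto-commit on, its own transaction commit).
    def insertOneByOne(conn: Connection, records: Iterator[(String, String, String)]): Unit = {
      val sql = "insert into kafka_message(timeseq, timeseq2, thread, message) values (?,?,?,?)"
      records.foreach { case (timeseq, thread, message) =>
        val ps = conn.prepareStatement(sql)
        ps.setString(1, timeseq)
        ps.setString(2, System.currentTimeMillis().toString)
        ps.setString(3, thread)
        ps.setString(4, message)
        ps.executeUpdate() // a separate round trip per record
        ps.close()
      }
    }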

    Writing to a database is an I/O-bound task, and parallelism alone will not necessarily make it faster. Below we focus on two approaches: batched commits and bulk copy insert.

    2. Batch Commit

    Batch commit simply means using executeBatch on a JDBC Statement. Let's go straight to the code (the ConnectionPool helper it relies on is sketched after the listing).

    import org.apache.kafka.common.serialization.StringDeserializer
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.kafka010.ConsumerStrategies._
    import org.apache.spark.streaming.kafka010.KafkaUtils
    import org.apache.spark.streaming.kafka010.LocationStrategies._
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.slf4j.LoggerFactory

    /**
      * Read data from Kafka and write it to the database in batches.
      */
    object KafkaToDB {
    
      val logger = LoggerFactory.getLogger(this.getClass)
    
      def main(args: Array[String]): Unit = {
        // Validate arguments
        if (args.length < 2) {
          System.err.println(
            s"""
               |Usage: KafkaToDB <brokers> <topics>
               |  <brokers> is a list of one or more Kafka brokers
               |  <topics> is a list of one or more kafka topics to consume from
               |""".stripMargin)
          System.exit(1)
        }
    
        // Parse arguments
        val Array(brokers, topics) = args
        // Topics are comma-separated
        val topicSet: Set[String] = topics.split(",").toSet
        val kafkaParams: Map[String, Object] = Map[String, Object](
          "bootstrap.servers" -> brokers,
          "key.deserializer" -> classOf[StringDeserializer],
          "value.deserializer" -> classOf[StringDeserializer],
          "group.id" -> "example",
          "auto.offset.reset" -> "latest",
          "enable.auto.commit" -> (false: java.lang.Boolean)
        )
    
        // Create the streaming context; every 2 seconds of data forms one batch
        val sparkConf = new SparkConf().setAppName("KafkaToDB")
        val streamingContext = new StreamingContext(sparkConf, Seconds(2))
    
        // 1. Create the input stream to receive the data. Stream operations are built on DStream; InputDStream extends DStream
        val stream = KafkaUtils.createDirectStream[String, String](
          streamingContext,
          PreferConsistent,
          Subscribe[String, String](topicSet, kafkaParams)
        )
    
        // 2. Transformations on the DStream:
        // take the message value, split it on commas, and convert it to a Tuple3
        val values = stream.map(_.value.split(","))
          .filter(x => x.length == 3)
          .map(x => new Tuple3[String, String, String](x(0), x(1), x(2)))
    
    
        // Print the first 10 records to the console for debugging
        values.print()
    
        // 3. Save to the database with foreachRDD
        val sql = "insert into kafka_message(timeseq,timeseq2, thread, message) values (?,?,?,?)"
        values.foreachRDD(rdd => {
          val count = rdd.count()
          println("-----------------count:" + count)
          if (count > 0) {
            rdd.foreachPartition(partitionOfRecords => {
              val conn = ConnectionPool.getConnection.orNull
              if (conn != null) {
                val ps = conn.prepareStatement(sql)
                try{
                  // Disable auto-commit so the batch is committed explicitly
                  conn.setAutoCommit(false)
                  partitionOfRecords.foreach(data => {
                    ps.setString(1, data._1)
                    ps.setString(2,System.currentTimeMillis().toString)
                    ps.setString(3, data._2)
                    ps.setString(4, data._3)
                    ps.addBatch()
                  })
                  ps.executeBatch()
                  conn.commit()
                } catch {
                  case e: Exception =>
                    logger.error("Error in execution of insert. " + e.getMessage)
                }finally {
                  ps.close()
                  ConnectionPool.closeConnection(conn)
                }
              }
            })
          }
        })
    
        streamingContext.start() // Start the computation
        streamingContext.awaitTermination() // Wait for the computation to terminate
    
      }
    }
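
    Both programs in this post rely on a ConnectionPool helper that is not shown here. A minimal stand-in, assuming PostgreSQL and placeholder URL and credentials, could look like the sketch below; it returns an Option[Connection] so callers can use .orNull as above. In a real deployment you would back it with an actual pool such as HikariCP or DBCP.

    import java.sql.{Connection, DriverManager}

    import org.slf4j.LoggerFactory

    // Minimal stand-in for the ConnectionPool used in the examples above.
    object ConnectionPool {

      private val logger = LoggerFactory.getLogger(this.getClass)

      // Placeholders: point these at your own database
      private val url = "jdbc:postgresql://localhost:5432/testdb"
      private val user = "postgres"
      private val password = "postgres"

      def getConnection: Option[Connection] = {
        try {
          Some(DriverManager.getConnection(url, user, password))
        } catch {
          case e: Exception =>
            logger.error("Error obtaining connection. " + e.getMessage)
            None
        }
      }

      def closeConnection(conn: Connection): Unit = {
        try {
          if (conn != null) conn.close()
        } catch {
          case e: Exception =>
            logger.error("Error closing connection. " + e.getMessage)
        }
      }
    }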
    

    3. Bulk Copy Insert

    We use PostgreSQL, whose JDBC driver provides a copy-based insert API. The main steps are:

    • 1. Get a database connection
    • 2. Create a CopyManager
    • 3. Wrap the stream data from Spark Streaming into an InputStream
    • 4. Execute copyIn

    The full program:
    import java.io.{ByteArrayInputStream, IOException, InputStream}
    import java.sql.Connection
    
    import org.apache.kafka.common.serialization.StringDeserializer
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.kafka010.ConsumerStrategies._
    import org.apache.spark.streaming.kafka010.KafkaUtils
    import org.apache.spark.streaming.kafka010.LocationStrategies._
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.postgresql.copy.CopyManager
    import org.postgresql.core.BaseConnection
    import org.slf4j.LoggerFactory
    
    object CopyInsert {
    
      val logger = LoggerFactory.getLogger(this.getClass)
    
      def main(args: Array[String]): Unit = {
        // Validate arguments
        if (args.length < 4) {
          System.err.println(
            s"""
               |Usage: CopyInsert <brokers> <topics> <duration> <batchsize>
               |  <brokers> is a list of one or more Kafka brokers
               |  <topics> is a list of one or more kafka topics to consume from
               |  <duration> is the batch interval in seconds
               |  <batchsize> is the number of records written per COPY call
               |""".stripMargin)
          System.exit(1)
        }
    
        // Parse arguments
        val Array(brokers, topics, duration, batchsize) = args
        // Topics are comma-separated
        val topicSet: Set[String] = topics.split(",").toSet
        val kafkaParams: Map[String, Object] = Map[String, Object](
          "bootstrap.servers" -> brokers,
          "key.deserializer" -> classOf[StringDeserializer],
          "value.deserializer" -> classOf[StringDeserializer],
          "group.id" -> "example",
          "auto.offset.reset" -> "latest",
          "enable.auto.commit" -> (false: java.lang.Boolean)
        )
    
        // Create the streaming context; every <duration> seconds of data forms one batch
        val sparkConf = new SparkConf().setAppName("CopyInsertIntoPostgreSQL")
        val streamingContext = new StreamingContext(sparkConf, Seconds(duration.toInt))
    
        // 1. Create the input stream to receive the data. Stream operations are built on DStream; InputDStream extends DStream
        val stream = KafkaUtils.createDirectStream[String, String](
          streamingContext,
          PreferConsistent,
          Subscribe[String, String](topicSet, kafkaParams)
        )
    
        // 2. Transformations on the DStream:
        // take the message value, split it on commas, and convert it to a Tuple3
        val values = stream.map(_.value.split(","))
          .filter(x => x.length == 3)
          .map(x => new Tuple3[String, String, String](x(0), x(1), x(2)))
    
    
        // Print the first 10 records to the console for debugging
        values.print()
    
        // 3. Save to the database with foreachRDD, using PostgreSQL COPY
        // See http://rostislav-matl.blogspot.jp/2011/08/fast-inserts-to-postgresql-with-jdbc.html
        values.foreachRDD(rdd => {
          val count = rdd.count()
          println("-----------------count:" + count)
          if (count > 0) {
            rdd.foreachPartition(partitionOfRecords => {
              val start = System.currentTimeMillis()
              val conn: Connection = ConnectionPool.getConnection.orNull
              if (conn != null) {
                val batch = batchsize.toInt
                var counter: Int = 0
                val sb: StringBuilder = new StringBuilder()
                // Get the underlying PostgreSQL BaseConnection
                val baseConnection = conn.getMetaData.getConnection.asInstanceOf[BaseConnection]
                // Create the CopyManager
                val cpManager: CopyManager = new CopyManager(baseConnection)
                partitionOfRecords.foreach(record => {
                  counter += 1
                  sb.append(record._1).append(",")
                    .append(System.currentTimeMillis()).append(",")
                    .append(record._2).append(",")
                    .append(record._3).append("\n")
                  if (counter == batch) {
                    // Build an InputStream over the accumulated CSV rows
                    val in: InputStream = new ByteArrayInputStream(sb.toString().getBytes())
                    // Execute copyIn
                    cpManager.copyIn("COPY kafka_message FROM STDIN WITH CSV", in)
                    println("-----------------batch---------------: " + batch)
                    counter = 0
                    sb.delete(0, sb.length)
                    closeInputStream(in)
                  }
                })
                // Flush any remaining records that did not fill a full batch
                val lastIn: InputStream = new ByteArrayInputStream(sb.toString().getBytes())
                cpManager.copyIn("COPY kafka_message FROM STDIN WITH CSV", lastIn)
                sb.delete(0, sb.length)
                counter = 0
                closeInputStream(lastIn)
                val end = System.currentTimeMillis()
                println("-----------------duration---------------ms :" + (end - start))
              }
            })
    
          }
        })
    
        streamingContext.start() // Start the computation
        streamingContext.awaitTermination() // Wait for the computation to terminate
      }
    
      def closeInputStream(in: InputStream): Unit = {
        try {
          in.close()
        } catch {
          case e: IOException =>
            logger.error("Error on close InputStream. " + e.getMessage)
        }
      }
        
    }
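
    Submitting the CopyInsert job might look roughly like the following; the jar name, master, broker list, topic, batch interval and batch size are all placeholders, and the spark-streaming-kafka-0-10 and PostgreSQL JDBC driver dependencies must be packaged with (or supplied to) the application.

    spark-submit --class CopyInsert \
      --master yarn \
      kafka-to-db-assembly.jar \
      kafka1:9092,kafka2:9092 kafka_message 2 5000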
    

    Other databases should offer a similar bulk-load path. MySQL, for example, has a setLocalInfileInputStream method on com.mysql.jdbc.Statement, which should serve the same purpose as the Copy Insert above, though I have not written an example to verify it. The documentation describes it as follows, for reference (original link).

    Sets an InputStream instance that will be used to send data to the MySQL server for a "LOAD DATA LOCAL INFILE" statement rather than a FileInputStream or URLInputStream that represents the path given as an argument to the statement.
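
    Based on that description, a rough and unverified sketch in the same spirit might look like the following. It assumes MySQL Connector/J 5.x (where the driver interface is com.mysql.jdbc.Statement), a kafka_message table with the same four columns as above, and that LOAD DATA LOCAL is enabled on both client and server; treat it only as a starting point.

    import java.io.ByteArrayInputStream
    import java.sql.DriverManager

    object MySqlBulkLoadSketch {
      def main(args: Array[String]): Unit = {
        // Placeholders: adjust the URL, credentials and table to your environment
        val conn = DriverManager.getConnection(
          "jdbc:mysql://localhost:3306/testdb?allowLoadLocalInfile=true", "user", "password")
        val stmt = conn.createStatement()
        try {
          // Unwrap to the driver-specific statement to reach setLocalInfileInputStream
          val mysqlStmt = stmt.unwrap(classOf[com.mysql.jdbc.Statement])
          val csv = "1,1509500000000,main,hello\n2,1509500000001,main,world\n"
          mysqlStmt.setLocalInfileInputStream(new ByteArrayInputStream(csv.getBytes("UTF-8")))
          // The driver reads from the InputStream instead of the file path given below
          stmt.execute("LOAD DATA LOCAL INFILE 'stream' INTO TABLE kafka_message FIELDS TERMINATED BY ','")
        } finally {
          stmt.close()
          conn.close()
        }
      }
    }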

    (End)
