Implementing MySQL update operations with Spark

Author: BIGUFO | Published 2018-05-19 13:14

    Background

    Currently Spark only supports the table-level save modes Append, Overwrite, ErrorIfExists and Ignore when writing to MySQL. Sometimes we need row-level operations such as update, i.e. we need to generate a statement of the form: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;

    Requirement: the goal is to leave existing code untouched and introduce no new API; a single extra option such as savemode=update should be enough to switch the behaviour on.

    Implementation

    To satisfy this requirement we have to modify the Spark source. First, create our own save mode, which only adds an Update value:

    // Same values as org.apache.spark.sql.SaveMode, plus a new Update mode
    public enum I4SaveMode {
        Append,
        Overwrite,
        ErrorIfExists,
        Ignore,
        Update
    }
    

    The JDBC data source is implemented mainly in JdbcRelationProvider, and the method we care about is createRelation. There we can map the incoming SaveMode onto our own mode and pass it down to saveTable. The modified method looks like this (every changed spot is commented):

    override def createRelation(
            sqlContext: SQLContext,
            mode: SaveMode,
            parameters: Map[String, String],
            df: DataFrame): BaseRelation = {
        val options = new JDBCOptions(parameters)
        val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
        // Map the built-in SaveMode onto our own I4SaveMode
        var saveMode = mode match {
            case SaveMode.Overwrite => I4SaveMode.Overwrite
            case SaveMode.Append => I4SaveMode.Append
            case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
            case SaveMode.Ignore => I4SaveMode.Ignore
        }
        // The key part: look for a savemode=update option (key compared case-insensitively)
        // and switch to Update mode if it is present
        val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
        if (parameterLower.keySet.contains("savemode")) {
            saveMode = if (parameterLower.get("savemode").get.equals("update")) I4SaveMode.Update else saveMode
        }
        val conn = JdbcUtils.createConnectionFactory(options)()
        try {
            val tableExists = JdbcUtils.tableExists(conn, options)
            if (tableExists) {
                saveMode match {
                    case I4SaveMode.Overwrite =>
                        if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
                            // In this case, we should truncate table and then load.
                            truncateTable(conn, options.table)
                            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
                        } else {
                        ......
        }
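    The snippet above is truncated inside the Overwrite branch. For completeness, here is a hedged sketch (not the author's verbatim code) of how the remaining match cases might look, assuming Update simply reuses the Append path once the table exists; the real difference only shows up in the SQL that saveTable builds:

                    // Sketch of the remaining cases (assumption, not the original patch):
                    case I4SaveMode.Append | I4SaveMode.Update =>
                        // Update follows the Append path; only the statement generated
                        // later in saveTable/getInsertStatement differs.
                        val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                        saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
                    case I4SaveMode.ErrorIfExists =>
                        throw new AnalysisException(
                            s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.")
                    case I4SaveMode.Ignore =>
                        // Table exists and Ignore was requested: leave the existing data untouched.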
    

    Next comes the saveTable method:

    def saveTable(
          df: DataFrame,
          tableSchema: Option[StructType],
          isCaseSensitive: Boolean,
          options: JDBCOptions,
          mode: I4SaveMode): Unit = { 
        ......
    // Thread the mode into both the statement construction and the per-partition writes
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode)
    .....
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode)
        )
      }
    

    getInsertStatement builds the SQL statement, and then each partition is iterated and saved with savePartition. Let's first look at how the statement construction changes (every changed spot is commented):

    def getInsertStatement(
          table: String,
          rddSchema: StructType,
          tableSchema: Option[StructType],
          isCaseSensitive: Boolean,
          dialect: JdbcDialect,
          mode: I4SaveMode): String = {
        val columns = if (tableSchema.isEmpty) {
          rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
        } else {
          val columnNameEquality = if (isCaseSensitive) {
            org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
          } else {
            org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
          } 
          val tableColumnNames = tableSchema.get.fieldNames
          rddSchema.fields.map { col =>
            val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
              throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
            }
            dialect.quoteIdentifier(normalizedName)
          }.mkString(",")
        } 
        val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    // Original: s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    // In Update mode the statement needs the extra ON DUPLICATE KEY UPDATE clause
    mode match {
        case I4SaveMode.Update =>
            val duplicateSetting = rddSchema.fields
                .map(x => dialect.quoteIdentifier(x.name))
                .map(name => s"$name=?")
                .mkString(",")
            s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
        case _ => s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    }
      }
    

    Only the Update mode needs the different statement. Next, the savePartition method, which does the actual writing:

     def savePartition(
          getConnection: () => Connection,
          table: String,
          iterator: Iterator[Row],
          rddSchema: StructType,
          insertStmt: String,
          batchSize: Int,
          dialect: JdbcDialect,
          isolationLevel: Int): Iterator[Byte] = {
        val conn = getConnection()
        var committed = false
    
        var finalIsolationLevel = Connection.TRANSACTION_NONE
        if (isolationLevel != Connection.TRANSACTION_NONE) {
          try {
            val metadata = conn.getMetaData
            if (metadata.supportsTransactions()) {
              // Update to at least use the default isolation, if any transaction level
              // has been chosen and transactions are supported
              val defaultIsolation = metadata.getDefaultTransactionIsolation
              finalIsolationLevel = defaultIsolation
              if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
                // Finally update to actually requested level if possible
                finalIsolationLevel = isolationLevel
              } else {
                logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                    s"falling back to default isolation level $defaultIsolation")
              }
            } else {
              logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
            }
          } catch {
            case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
          }
        }
        val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
    
        try {
          if (supportsTransactions) {
            conn.setAutoCommit(false) // Everything in the same db transaction.
            conn.setTransactionIsolation(finalIsolationLevel)
          }
          val stmt = conn.prepareStatement(insertStmt)
          val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
          val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
          val numFields = rddSchema.fields.length
    
          try {
            var rowCount = 0
            while (iterator.hasNext) {
              val row = iterator.next()
              var i = 0
              while (i < numFields) {
                if (row.isNullAt(i)) {
                  stmt.setNull(i + 1, nullTypes(i))
                } else {
                  setters(i).apply(stmt, row, i)
                }
                i = i + 1
              }
              stmt.addBatch()
              rowCount += 1
              if (rowCount % batchSize == 0) {
                stmt.executeBatch()
                rowCount = 0
              }
            }
            if (rowCount > 0) {
              stmt.executeBatch()
            }
          } finally {
            stmt.close()
          }
          if (supportsTransactions) {
            conn.commit()
          }
          committed = true
          Iterator.empty
        } catch {
          case e: SQLException =>
            val cause = e.getNextException
            if (cause != null && e.getCause != cause) {
              // If there is no cause already, set 'next exception' as cause. If cause is null,
              // it *may* be because no cause was set yet
              if (e.getCause == null) {
                try {
                  e.initCause(cause)
                } catch {
                  // Or it may be null because the cause *was* explicitly initialized, to *null*,
                  // in which case this fails. There is no other way to detect it.
                  // addSuppressed in this case as well.
                  case _: IllegalStateException => e.addSuppressed(cause)
                }
              } else {
                e.addSuppressed(cause)
              }
            }
            throw e
        } finally {
          if (!committed) {
            // The stage must fail.  We got here through an exception path, so
            // let the exception through unless rollback() or close() want to
            // tell the user about another problem.
            if (supportsTransactions) {
              conn.rollback()
            }
            conn.close()
          } else {
            // The stage must succeed.  We cannot propagate any exception close() might throw.
            try {
              conn.close()
            } catch {
              case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
            }
          }
        }
      }
    

    The overall idea: before iterating over a partition's rows, a template of setters is built from the data's schema, so that during iteration each setter is simply applied to the corresponding field of every row; there is no need to re-examine the data type for every row.
    Without update mode: insert into tb (id,name,age) values (?,?,?)
    With update mode: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
    That is, there are twice as many placeholders, so in update mode each row's values have to be fed to the PreparedStatement a second time. The original makeSetter method looks like this:

    private def makeSetter(
          conn: Connection,
          dialect: JdbcDialect,
          dataType: DataType): JDBCValueSetter = dataType match {
        case IntegerType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setInt(pos + 1, row.getInt(pos))
        case LongType =>
          (stmt: PreparedStatement, row: Row, pos: Int) =>
            stmt.setLong(pos + 1, row.getLong(pos))
        ...
      }
    

    We only need to add a relative position parameter, offset, to control which row field a placeholder reads from, turning it into:

    private def makeSetter(
          conn: Connection,
          dialect: JdbcDialect,
          dataType: DataType): JDBCValueSetter = dataType match {
        case IntegerType =>
          (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
            stmt.setInt(pos + 1, row.getInt(pos - offset))
        case LongType =>
          (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
            stmt.setLong(pos + 1, row.getLong(pos - offset))
        ...
      }
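    One detail the snippet glosses over: the JDBCValueSetter type alias in JdbcUtils has to gain the extra Int as well, otherwise the four-argument lambdas will not compile. Assuming the original alias is the three-argument one found in Spark's JdbcUtils, the change is simply:

        // Original (Spark): private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
        // Modified to carry the offset:
        private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) => Unit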
    

    In non-update mode offset is always 0. In update mode, offset is 0 while the placeholder index is still below the field count, and equal to the field count once it goes past it. The modified savePartition method:

    def savePartition(
                     getConnection: () => Connection,
                     table: String,
                     iterator: Iterator[Row],
                     rddSchema: StructType,
                     insertStmt: String,
                     batchSize: Int,
                     dialect: JdbcDialect,
                     isolationLevel: Int,
                     mode: I4SaveMode): Iterator[Byte] = {
        ...
        // Is this an update write?
        val isUpdateMode = mode == I4SaveMode.Update
        val stmt = conn.prepareStatement(insertStmt)
        val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
        val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
        val length = rddSchema.fields.length
        // In update mode there are twice as many placeholders
        val numFields = if (isUpdateMode) length * 2 else length
        val midField = numFields / 2
        try {
            var rowCount = 0
            while (iterator.hasNext) {
                val row = iterator.next()
                var i = 0
                while (i < numFields) {
                    if (isUpdateMode) {
                        // Update mode, first half of the placeholders: offset is 0
                        i < midField match {
                            case true =>
                                if (row.isNullAt(i)) {
                                    stmt.setNull(i + 1, nullTypes(i))
                                } else {
                                    setters(i).apply(stmt, row, i, 0)
                                }
                            // Update mode, second half: offset is midField, i.e. the field count
                            case false =>
                                if (row.isNullAt(i - midField)) {
                                    stmt.setNull(i + 1, nullTypes(i - midField))
                                } else {
                                    setters(i - midField).apply(stmt, row, i, midField)
                                }
                        }
                    
                    } else {
                        if (row.isNullAt(i)) {
                            stmt.setNull(i + 1, nullTypes(i))
                        } else {
                            setters(i).apply(stmt, row, i, 0)
                        }
                    }
                    i = i + 1
                }
              ...
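    To make the slot-to-field mapping concrete, here is a small self-contained illustration (hypothetical three-column row, not part of the patch) of how the same row values fill both halves of the doubled placeholder list:

        // Hypothetical example: a row (id, name, age) and the update-mode statement
        // INSERT INTO tb (id,name,age) VALUES (?,?,?) ON DUPLICATE KEY UPDATE id=?,name=?,age=?
        val values = Seq[Any](1, "tom", 18)
        val n = values.length
        val bindings = (0 until 2 * n).map { i =>
          val offset = if (i < n) 0 else n      // 0 for the VALUES half, n for the UPDATE half
          (i + 1) -> values(i - offset)         // JDBC parameter slots are 1-based
        }
        // bindings == Vector((1,1), (2,tom), (3,18), (4,1), (5,tom), (6,18))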
    

    After changing the source you need to recompile, repackage, and replace the corresponding jar in your deployment. There is a shortcut: create the same package names in your own project, make the changes there, build a jar, and then replace only the affected class files inside the deployed jar with the ones from your build.

    Usage

    When the update mode is needed:

    df.write.option("saveMode","update").jdbc(...)
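
    A slightly fuller (hypothetical) example; the JDBC URL, table name and credentials below are placeholders:

        import java.util.Properties

        // Placeholder connection details; replace with your own.
        val props = new Properties()
        props.put("user", "root")
        props.put("password", "secret")

        df.write
          .option("saveMode", "update")   // key is matched case-insensitively by the patched createRelation; value must be "update"
          .jdbc("jdbc:mysql://127.0.0.1:3306/test", "tb", props)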
    

    References

    https://blog.csdn.net/cjuexuan/article/details/52333970



      Reader comments

      • 宇智波_佐助: Hi, also about the offset part: the later cases deal with date/time conversion; do those also need pos - offset?
      • 宇智波_佐助: Hi, is JdbcUtils.getSchemaOption(conn, jdbcOptions) a method you wrote yourself? I can't find getSchemaOption in the JdbcUtils class.
        宇智波_佐助: @BIGUFO OK, I'm on Spark 2.1.0. Thanks.
        BIGUFO: I'm on Spark 2.3, which does have this method. Whichever version you use, just adapt it following the same pattern.
