Spark JDBC Series -- A Brief Source Code Analysis

Author: wuli_小博 | Published 2017-12-04 17:41

    This article takes a brief look at the key source code Spark uses to read from a database.

    How Spark reads database data

    Like other data-mapping frameworks (such as Hibernate and MyBatis), Spark cannot avoid a JDBC connection if it wants to read from a database; that is, after all, the official channel through which code talks to the database. To read database data quickly, Spark has to solve, among other things:

    • Distributed reads
    • Mapping the raw data into an RDD/DataFrame

    This short article therefore walks through the source code around these two aspects.

    For the Spark database-reading APIs themselves, see the companion article: Spark JDBC Series -- Four Ways to Fetch Data.

    Source Code Analysis

    1. The common entry point of the JDBC APIs

    Entry-point source:

    org.apache.spark.sql.DataFrameReader
    ...
    private def jdbc(
      url: String,
      table: String,
      parts: Array[Partition],
      connectionProperties: Properties): DataFrame = {
        val props = new Properties()
        extraOptions.foreach { case (key, value) =>
          props.put(key, value)
        }
        // connectionProperties should override settings in extraOptions
        props.putAll(connectionProperties)
        //key point: the JDBCRelation is built from the partition array here
        val relation = JDBCRelation(url, table, parts, props)(sparkSession)
        //creates the logical relation; the actual read is only triggered by an action
        sparkSession.baseRelationToDataFrame(relation)
    }
    

    Looking at this code, although the parameters of the four read APIs differ slightly, they are all eventually converted into a single Array[Partition], i.e. an array of partition predicates.
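
    For reference, here is a minimal sketch (the host, database, table name and credentials are placeholders) of the public overloads that all funnel into the private jdbc method above, and the partition arrays they end up producing:

    import java.util.Properties
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("jdbc-entry-demo").master("local[*]").getOrCreate()
    val url = "jdbc:mysql://mysqlHost:3306/database"
    val props = new Properties()
    props.setProperty("user", "username")
    props.setProperty("password", "pwd")

    // 1) no partitioning info -> a single JDBCPartition(null, 0): the whole table in one task
    val whole = spark.read.jdbc(url, "table", props)

    // 2) column bounds -> columnPartition(...) builds range predicates (see section 2)
    val ranged = spark.read.jdbc(url, "table", "id", 1L, 1000L, 10, props)

    // 3) explicit predicates -> one JDBCPartition per element of the array
    val predicates = Array("date = '2017-11-29'", "date = '2017-11-30'")
    val custom = spark.read.jdbc(url, "table", predicates, props)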

    2. How the column-based API builds its partitions

    Here we look at the partitioning logic of the API that takes a long-typed column. First, the source:

    def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
        if (partitioning == null || partitioning.numPartitions <= 1 ||
          partitioning.lowerBound == partitioning.upperBound) {
          //single-partition mode falls into this branch
          return Array[Partition](JDBCPartition(null, 0))
        }
        
        //validity checks
        val lowerBound = partitioning.lowerBound
        val upperBound = partitioning.upperBound
        ....
          
        //adjust the number of partitions when the range is smaller than numPartitions
        val numPartitions =
          if ((upperBound - lowerBound) >= partitioning.numPartitions) {
            partitioning.numPartitions
          } else {
            upperBound - lowerBound
          }
          
        //compute the stride
        val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
        val column = partitioning.column
        var i: Int = 0
        var currentValue: Long = lowerBound
        var ans = new ArrayBuffer[Partition]()
        
        //step from lowerBound to upperBound by the stride and assemble each partition's WHERE clause
        while (i < numPartitions) {
          //note: the first and last partitions keep only a one-sided bound, e.g. JDBCPartition(id >= 901,9)
          val lBound = if (i != 0) s"$column >= $currentValue" else null
          currentValue += stride
          val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
          val whereClause =
            if (uBound == null) {
              lBound
            } else if (lBound == null) {
              s"$uBound or $column is null"
            } else {
              s"$lBound AND $uBound"
            }
          ans += JDBCPartition(whereClause, i)
          i = i + 1
        }
        ans.toArray
      }
    

    The test input and the resulting partitions:

    Input:
    lowerBound=1, upperBound=1000, numPartitions=10
    
    Resulting partition array:
    JDBCPartition(id < 101 or id is null,0), 
    JDBCPartition(id >= 101 AND id < 201,1), 
    JDBCPartition(id >= 201 AND id < 301,2), 
    JDBCPartition(id >= 301 AND id < 401,3), 
    JDBCPartition(id >= 401 AND id < 501,4), 
    JDBCPartition(id >= 501 AND id < 601,5), 
    JDBCPartition(id >= 601 AND id < 701,6), 
    JDBCPartition(id >= 701 AND id < 801,7), 
    JDBCPartition(id >= 801 AND id < 901,8), 
    JDBCPartition(id >= 901,9)
    

    This API is easy to misuse: if you pass the minimum and maximum of an ID range (rather than the true minimum and maximum of the whole table), the job will still pull the entire table, and the data will be skewed, because the first and last partitions keep only a one-sided WHERE condition. If you need to restrict the range or add further conditions, use the API that accepts custom partition predicates, as sketched below.
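
    As a sketch of that safer route (reusing the placeholder spark, url and props from the snippet in section 1, and assuming the id slice from 901 up to 1901 is what you actually want), each generated predicate becomes one partition's WHERE clause, so no partition is left open-ended:

    // build explicit, fully bounded predicates for the slice we care about
    val minId = 901L          // start of the desired slice, NOT the table's true minimum
    val maxId = 1901L         // exclusive end of the desired slice
    val step  = 100L
    val slicePredicates = (minId until maxId by step).map { lo =>
      s"id >= $lo AND id < ${math.min(lo + step, maxId)}"
    }.toArray

    // one JDBCPartition (and therefore one task) per predicate, none of them one-sided
    val sliced = spark.read.jdbc(url, "table", slicePredicates, props)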

    3. Mapping the result set

    Relevant code:

    org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
    //resolve the DataFrame schema, i.e. map the database column types to Spark data types
    override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
    
    //the concrete implementation
    org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
    def resolveTable(url: String, table: String, properties: Properties): StructType = {
      //pick the JDBC dialect based on the URL
      val dialect = JdbcDialects.get(url)
      val ncols = rsmd.getColumnCount
      val fields = new Array[StructField](ncols)
      var i = 0
      ....
      
      while (i < ncols) {
        val columnName = rsmd.getColumnLabel(i + 1)
        val dataType = rsmd.getColumnType(i + 1)
        val typeName = rsmd.getColumnTypeName(i + 1)
        val fieldSize = rsmd.getPrecision(i + 1)
        val fieldScale = rsmd.getScale(i + 1)
        ....
        
        //use the dialect's mapping first; fall back to the default rules when the dialect has none
        val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
            getCatalystType(dataType, fieldSize, fieldScale, isSigned))
        fields(i) = StructField(columnName, columnType, nullable, metadata.build())
        i = i + 1
      }
      return new StructType(fields)
    }

    //examples from the default mapping rules (getCatalystType):
    val answer = sqlType match {
      ....   
      case java.sql.Types.BLOB          => BinaryType
      case java.sql.Types.BOOLEAN       => BooleanType
      case java.sql.Types.CHAR          => StringType
      case java.sql.Types.CLOB          => StringType
      case java.sql.Types.DATALINK      => null
      case java.sql.Types.DATE          => DateType
      case java.sql.Types.DECIMAL
        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
      case java.sql.Types.DECIMAL       => DecimalType.SYSTEM_DEFAULT
      case java.sql.Types.DISTINCT      => null
      case java.sql.Types.DOUBLE        => DoubleType
      case java.sql.Types.FLOAT         => FloatType
      ....
    }
    

    Here is the MySQL dialect as an example:

    All dialect implementations live under the package org.apache.spark.sql.jdbc.*; see that package for the others.
    
    The MySQL dialect:
    private case object MySQLDialect extends JdbcDialect {
    
      override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
    
      override def getCatalystType(
          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
        //the key mapping logic
        if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
          // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
          // byte arrays instead of longs.
          md.putLong("binarylong", 1)
          Option(LongType)
        } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
          Option(BooleanType)
        } else None
      }
      ....
    }
    

    As the source shows, the MySQL dialect only constrains the BIT and TINYINT types; every other type falls back to Spark's default rules. When reading data, you should therefore check whether Spark's dialect mapping affects your existing data, to avoid distorting values.
    At this point the JDBCRelation object has been fully constructed.
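
    If the default mapping does not fit your data, Spark also lets you register your own dialect. Below is a minimal sketch, assuming (purely for illustration) that you want TINYINT columns read as IntegerType; registered dialects take priority over the built-in ones for URLs they can handle:

    import java.sql.Types
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
    import org.apache.spark.sql.types.{DataType, IntegerType, MetadataBuilder}

    //sketch only: override a single mapping and fall back to the defaults for everything else
    object MyMySQLDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")

      override def getCatalystType(
          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
        //illustrative assumption: map TINYINT to IntegerType instead of the default
        if (sqlType == Types.TINYINT) Some(IntegerType) else None
      }
    }

    JdbcDialects.registerDialect(MyMySQLDialect)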

    4. RDD construction and logical partition generation

    Based on the JDBCRelation built above, the SparkSession adds the scan to the logical execution plan. When an action is triggered, the logical plan is converted into a physical plan:

    org.apache.spark.sql.SparkSession
    //builds the logical plan; the details are not covered further here
    def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
      Dataset.ofRows(self, LogicalRelation(baseRelation))
    }
    
    org.apache.spark.sql.execution.datasources.DataSourceStrategy
    //physical planning strategy
    object DataSourceStrategy extends Strategy with Logging {
      def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
        case PhysicalOperation.....
    
        //JDBCRelation extends PrunedFilteredScan, so this case matches and buildScan is called
        case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
          pruneFilterProject(
            l,
            projects,
            filters,
            (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
    
        case PhysicalOperation..... 
      }    
    

    When JDBCRelation's buildScan runs, it calls JDBCRDD's scanTable method to build a new RDD. Any filter conditions added before the computation are merged into the JDBC WHERE clause, joined with AND:

    private[jdbc] class JDBCRDD(
        sc: SparkContext,
        getConnection: () => Connection,
        schema: StructType,
        fqTable: String,
        columns: Array[String],
        filters: Array[Filter],
        partitions: Array[Partition],
        url: String,
        properties: Properties)
      extends RDD[InternalRow](sc, Nil) {
    
      override def getPartitions: Array[Partition] = partitions
      
      .....
        
      private def getWhereClause(part: JDBCPartition): String = {
        if (part.whereClause != null && filterWhereClause.length > 0) {
          "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
        } else if (part.whereClause != null) {
          "WHERE " + part.whereClause
        } else if (filterWhereClause.length > 0) {
          "WHERE " + filterWhereClause
        } else {
          ""
        }
      }
      
      //compute runs when an action is triggered: it executes this partition's SQL query
      //and maps the result set according to the schema resolved earlier
      override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
        new Iterator[InternalRow] {
        ....
        //implementation details omitted: essentially the JDBC query and the type mapping
    }
    

    An example of using filter conditions:

    val url = "jdbc:mysql://mysqlHost:3306/database"
    val tableName = "table"
    val columnName = "id"
    val lowerBound = getMinId()
    val upperBound = getMaxId()
    val numPartitions = 200
    
    // set the connection user & password
    val prop = new java.util.Properties
    prop.setProperty("user","username")
    prop.setProperty("password","pwd")
    
    // read from MySQL with filters applied; note that this overload requires numPartitions
    val jdbcDF = sqlContext.read
      .jdbc(url, tableName, columnName, lowerBound, upperBound, numPartitions, prop)
      .where("date='2017-11-30'")
      .filter("name is not null")
    

    where and filter are equivalent here: the pushed-down predicates take effect in the generated WHERE clause, and multiple conditions are joined with AND together with each partition's range predicate (see getWhereClause above).
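
    Continuing the example above (a sketch; the exact plan text differs across Spark versions), explain() shows the JDBC scan together with the pushed filters that will be appended to each partition's WHERE clause, and swapping where for filter yields the same plan:

    // print the physical plan: the JDBC scan lists the pushed-down filters
    jdbcDF.explain(true)

    // where(...) and filter(...) are interchangeable; this builds the same plan as jdbcDF
    val sameDF = sqlContext.read
      .jdbc(url, tableName, columnName, lowerBound, upperBound, numPartitions, prop)
      .filter("date='2017-11-30'")
      .where("name is not null")
    sameDF.explain(true)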

    Conclusion

    When reading database data, you can step into the corresponding source code and debug through it to see exactly what happens.
