实现自己的 JDBC 驱动

作者: 何晓杰Dev | 来源:发表于2020-06-08 11:57 被阅读0次

    本文可能是全网最全面最详细的 JDBC 驱动开发教程,转载请注明出处

    要自己实现一个 JDBC 驱动无疑是较为困难的,在此之前我查阅了很多资料,也查到了许多正在提问的帖子,最后都是无疾而终,提问者没有得到想要的回答。而我由于刚好需要做这方面的开发,也只能硬啃 JDBC 驱动的文档了,写这篇算是一个资料的整理,同时我也已经完成了一个通用的 JDBC 生成工具,文末可以下载。


    首先还是来读文档,JDBC 的文档可谓是非常详细了,直接翻到 6.3 JDBC 4.0 API Compliance,它很明确的告诉了你应该做什么:

    1. 支持自动加载的,继承自 java.sql.Driver 的类
    2. 支持“仅向前”的结果集
    3. 支持“只读的”结果集并发类型
    4. 支持批量更新
    5. 实现以下接口
      1) java.sql.Driver
      2) java.sql.DatabaseMetaData 
      3) java.sql.ParameterMetaData 
      4) java.sql.ResultSetMetaData 
      5) java.sql.Wrapper
      6) javax.sql.DataSource
      7) java.sql.Connection
      8) java.sql.Statement
      9) java.sql.PreparedStatement
      10) java.sql.CallableStatement
      11) java.sql.ResultSet
    

    好的,看到这里我们会遇到官方文档挖的第一个坑,其实并非这么多东西都要实现的,实际开发中,只实现以下内容也可以:

      1) java.sql.Driver
      2) java.sql.DatabaseMetaData 
      3) java.sql.ResultSetMetaData 
      4) java.sql.Connection
      5) java.sql.Statement
      6) java.sql.PreparedStatement
      7) java.sql.ResultSet
    

    只要有这些东西,就足以支撑起一个完整的 JDBC 驱动。

    那么接下来就是实现了,不得不说,整个 JDBC 的协议真的又臭又长,所有的类加起来一共要实现 584 个接口函数,而且里面有一大半是不起任何作用的。

    限于篇幅,在这里我不可能把所有的函数予以列出,只看最重要的那些,其他的请大家自行发挥。


    那先来看 Driver:

    class MyDriver: Driver {
    
        companion object {
            init {
                try {
                    DriverManager.registerDriver(MyDriver())
                } catch (e: Exception) {
                    throw RuntimeException("Can't register $DRIVER_NAME", e)
                }
            }
        }
    
        override fun acceptsURL(url: String) = url.toLowerCase().startsWith(JDBC_URL)
    
        override fun connect(url: String, info: Properties?): Connection {
            if (!acceptsURL(url)) throw SQLException("Invalid URL: $url")
            val props = MyDriverUtil.parseMergeProperties(url.replace("jdbc:", ""), info)
            return MyConnection(props)
        }
    
        ... ...
    }
    

    使用 companion objectinit 块来完成驱动的自动加载,需要特别注意的是,千万不要不带 companion object,如果不带的话,init 块实质上是类的构造函数,而不是静态初始化块。

    acceptsURL 方法指出了什么样的 URL 可以被驱动接受,比如说我们经常在使用 mysql 驱动时看到其 JDBC URL 为 jdbc:mysql://,就是在 acceptsURL 里接受了这样的前缀。

    connect 方法用于获取一个连接,它是在 DriverManager.getConnection 时自动被调用的,返回一个非空的 Connection 对象。

    这里有个函数,MyDriverUtil.parseMergeProperties 用于将 url 的参数和 info: Properties 所携带的参数进行拼接,该函数实现如下:

    fun parseMergeProperties(url: String, prop: Properties?) = mutableMapOf<String, String>().apply {
        val uri = URI(url)
        this[PROP_HOST] = uri.host
        this[PROP_PORT] = (if (uri.port == -1) DEFAULT_PORT else uri.port).toString()
        this[PROP_PATH] = uri.path.replaceFirst("/", "")
        if (uri.query != null) {
            this += uri.query.split("&").map { p -> p.split("=").let { i -> Pair(i[0], i[1]) } }.toMap()
        }
        if (prop != null) {
            this += prop.map { e -> Pair(e.key.toString(), e.value.toString()) }.toMap()
        }
    }.toMap()
    

    在这段代码中出现了 MyConnection 这个对象,我们下面就来看看如何实现它。


    实现 Connection

    class MyConnection(props: Map<String, String>) : Connection {
    
        val io = MyIO(props)
        private var isClosed = false
        private val autoCommit = true
    
        override fun prepareStatement(sql: String) = MyPreparedStatement(sql, this)
        override fun prepareStatement(sql: String, resultSetType: Int, resultSetConcurrency: Int) = MyPreparedStatement(sql, this)
        override fun prepareStatement(sql: String, resultSetType: Int, resultSetConcurrency: Int, resultSetHoldability: Int) = MyPreparedStatement(sql, this)
        override fun prepareStatement(sql: String, autoGeneratedKeys: Int) = MyPreparedStatement(sql, this)
        override fun prepareStatement(sql: String, columnIndexes: IntArray?) = MyPreparedStatement(sql, this)
        override fun prepareStatement(sql: String, columnNames: Array<out String>?) = MyPreparedStatement(sql, this)
        override fun getAutoCommit() = autoCommit
        override fun getWarnings(): SQLWarning? = null
        override fun getCatalog(): String? {
            checkConnection()
            return null
        }
        override fun isValid(timeout: Int) = isClosed
        override fun close() {
            isClosed = true
        }
        override fun isClosed() = isClosed
        override fun isReadOnly() = false
        override fun createStatement() = MyStatement(this)
        override fun createStatement(resultSetType: Int, resultSetConcurrency: Int) = MyStatement(this)
        override fun createStatement(resultSetType: Int, resultSetConcurrency: Int, resultSetHoldability: Int) = MyStatement(this)
        override fun getMetaData() = MyDatabaseMetaData()
        override fun getTransactionIsolation() = Connection.TRANSACTION_NONE
        private fun checkConnection() {
            if (isClosed()) throw SQLException("Connection is closed")
        }
    
        ... ...
    
    }
    

    这里需要实现的东西就比较多了,需要实现所有返回 StatementPreparedStatement 的函数,也需要实现 isClosedclose 这类改变或检查状态的函数。

    这里有一个 MyIO 的对象,是一个自定义的类,用于完成真正的数据获取工作,在一会的代码中就会用到它了。


    在实现 Statement 之前,要先做一点准备工作,有一些公用的东西需要被抽象出来。

    abstract class MyAbsStatement {
        internal var isClosed = false
        internal var connection: MyConnection
        internal lateinit var resultSet: ResultSet
        protected lateinit var sql: String
    
        constructor(sql: String, conn: MyConnection) {
            this.sql = sql
            this.connection = conn
        }
        constructor(conn: MyConnection) {
            this.connection = conn
        }
    
        open fun executeForResultSet(sql: String): Boolean {
            if (isClosed) throw SQLException("This statement is closed.")
            try {
                resultSet = connection.io.internalExecuteQuery(sql)
                return true
            } catch (th: Throwable) {
                throw SQLException(th)
            }
        }
    
        open fun executeForResult(sql: String): Int {
            if (isClosed) throw SQLException("This statement is closed.")
            try {
                return connection.io.internalExecuteUpdate(sql)
            } catch (th: Throwable) {
                throw SQLException(th)
            }
        }
    }
    

    有了这个类之后,我们可以继承它,并且实现 Statement 接口:

    class MyStatement(conn: MyConnection) : MyAbsStatement(conn), Statement {
    
        private val batchOps = mutableListOf<String>()
    
        override fun execute(sql: String, autoGeneratedKeys: Int) = execute(sql)
        override fun execute(sql: String, columnIndexes: IntArray?) = execute(sql)
        override fun execute(sql: String, columnNames: Array<out String>?) = execute(sql)
    
        override fun clearBatch() {
            batchOps.clear()
        }
    
        override fun getResultSetType() = ResultSet.TYPE_FORWARD_ONLY
        override fun isCloseOnCompletion() = false
        override fun <T : Any> unwrap(iface: Class<T>): T? = null
        override fun getMaxRows() = 0
        override fun getConnection() = this.connection
        override fun getWarnings(): SQLWarning? = null
    
        override fun executeQuery(sql: String): ResultSet {
            this.execute(sql)
            return this.getResultSet()
        }
    
        override fun close() {
            isClosed = true
        }
    
        override fun isClosed() = this.isClosed
        override fun getMaxFieldSize() = 0
        override fun isWrapperFor(iface: Class<*>) = false
        override fun getUpdateCount() = -1
        override fun getFetchSize() = 0
    
        override fun executeBatch() = IntArray(batchOps.size).apply {
            this@MyStatement.batchOps.forEachIndexed { index, sql ->
                try {
                    this@MyStatement.execute(sql)
                    this[index] = SUCCESS_NO_INFO
                } catch (th: Throwable) {
                    throw BatchUpdateException(th)
                }
            }
        }
    
        override fun getQueryTimeout() = 0
        override fun isPoolable() = false
    
        override fun addBatch(sql: String) {
            batchOps.add(sql)
        }
    
        override fun getGeneratedKeys(): ResultSet? = null
        override fun getResultSetConcurrency() = ResultSet.CONCUR_READ_ONLY
        override fun getResultSet() = this.resultSet
        override fun execute(sql: String) = executeForResultSet(sql)
        override fun executeUpdate(sql: String) = executeForResult(sql)
        override fun executeUpdate(sql: String, autoGeneratedKeys: Int) = executeUpdate(sql)
        override fun executeUpdate(sql: String, columnIndexes: IntArray?) = executeUpdate(sql)
        override fun executeUpdate(sql: String, columnNames: Array<out String>?) = executeUpdate(sql)
        override fun getFetchDirection() = 0
        override fun getResultSetHoldability() = ResultSet.CLOSE_CURSORS_AT_COMMIT
        override fun getMoreResults() = false
        override fun getMoreResults(current: Int) = false
    
        ... ...
    }
    

    这里主要实现 execute 相关的方法,这个时候定义在抽象类里的 executeForResultexecuteForResultSet 就有了用武之地,它们可以将所有的请求一并接管起来。

    同样需要注意的,是在 JDBC 文档内所述的,必须支持批量更新,在 Statement 内需要予以支持。


    Statement 类似的,下面来实现 PreparedStatement,与 Statement 不同的地方在于,PreparedStatement 需要用户自己处理替换问号占位符的操作。

    先给出这个操作的代码:

    private fun replaceSQL() {
        var idx = 1
        while (sql.indexOf("?") > 1) {
            try {
                val p = parameters[idx]
                sql = sql.replaceFirst("?", if (p == null) "null" else "'$p'")
            } catch (e: IndexOutOfBoundsException) {
                throw SQLException("Can't find defined parameter for position: $idx")
            }
            idx++
        }
    }
    

    然后来看看实现 PreparedStatement 需要做些什么:

    
    class MyPreparedStatement(sql: String, conn: MyConnection) : MyAbsStatement(sql, conn), PreparedStatement {
    
        private val parameters = mutableMapOf<Int, String?>()
    
        override fun execute(): Boolean {
            replaceSQL()
            return super.executeForResultSet(sql)
        }
    
        override fun execute(sql: String): Boolean {
            this.sql = sql
            return this.execute()
        }
    
        override fun execute(sql: String, autoGeneratedKeys: Int) = execute(sql)
        override fun execute(sql: String, columnIndexes: IntArray?) = execute(sql)
        override fun execute(sql: String, columnNames: Array<out String>?) = execute(sql)
        override fun getResultSetType() = ResultSet.TYPE_FORWARD_ONLY
    
        override fun clearParameters() {
            parameters.clear()
        }
    
        override fun getConnection() = this.connection
        override fun getWarnings(): SQLWarning? = null
        override fun getParameterMetaData(): ParameterMetaData? = null
    
        override fun executeQuery(): ResultSet {
            this.execute()
            return this.resultSet
        }
    
        override fun executeQuery(sql: String): ResultSet {
            execute(sql)
            return this.resultSet
        }
    
        override fun executeUpdate(): Int {
            replaceSQL()
            return executeForResult(sql)
        }
    
        override fun executeUpdate(sql: String): Int {
            replaceSQL()
            return executeForResult(sql)
        }
    
        override fun executeUpdate(sql: String, autoGeneratedKeys: Int) = executeUpdate(sql)
        override fun executeUpdate(sql: String, columnIndexes: IntArray?) = executeUpdate(sql)
        override fun executeUpdate(sql: String, columnNames: Array<out String>?) = executeUpdate(sql)
    
        override fun close() {
            isClosed = true
        }
    
        override fun isCloseOnCompletion() = false
        override fun getMaxRows() = 0
        override fun isClosed() = isClosed
        override fun getMaxFieldSize() = 0
        override fun getUpdateCount() = 0
        override fun getFetchSize() = 0
        override fun executeBatch(): IntArray? = null
        override fun getQueryTimeout() = 0
        override fun isPoolable() = false
        override fun getGeneratedKeys(): ResultSet? = null
        override fun getResultSetConcurrency() = ResultSet.CONCUR_READ_ONLY
        override fun getResultSet() = this.resultSet
        override fun getMetaData() = MyResultSetMetaData()
        override fun getFetchDirection() = ResultSet.FETCH_FORWARD
        override fun getResultSetHoldability() = ResultSet.CLOSE_CURSORS_AT_COMMIT
        override fun getMoreResults() = false
        override fun getMoreResults(current: Int) = false
    
         override fun setFloat(parameterIndex: Int, x: Float) {
              pushIntoParameters(parameterIndex, x.toString())
          }
        override fun setLong(parameterIndex: Int, x: Long) {
              pushIntoParameters(parameterIndex, x.toString())
        }
        override fun setDouble(parameterIndex: Int, x: Double) {
              pushIntoParameters(parameterIndex, x.toString())
        }
        override fun setInt(parameterIndex: Int, x: Int) {
              pushIntoParameters(parameterIndex, x.toString())
        }
        override fun setString(parameterIndex: Int, x: String?) {
            pushIntoParameters(parameterIndex, x)
        }
        override fun setTimestamp(parameterIndex: Int, x: Timestamp?) {
              pushIntoParameters(parameterIndex, x.toString())
        }
    
        private fun pushIntoParameters(index: Int, value: String?) {
            if (index <= 0) throw SQLException("Invalid position for parameter ($index)")
            this.parameters[index] = value
        }
    
        ... ...
    }
    

    可以清楚的看到,在这里主要是用 Map 来保存需要替换的值,然后在执行的时候将真实的参数替换进问号中。然后对于执行 SQL 的地方,依然调用了抽象类里的 executeForResultexecuteForResultSet 方法。


    好了,现在我们已经完成了 StatementPreparedStatement,你可能要问了,能不能跑起代码看看效果呀?答案是否定的,因为还没有做好完整的准备,我们还需要一些东西,下面这个也很关键,是 ResultSet

    其实这也是 JDBC 坑的一个地方,通常情况下我们可能会希望写一点代码就运行起来看看效果,但是写 JDBC 驱动时偏偏不能,还是要先完整实现才可以。

    一个标准的 ResultSet 实现如下:

    
    class MyResultSet : ResultSet {
    
        private var isClosed = false
        private var position = -1
        private lateinit var fields: List<String>
        private lateinit var result: List<List<String>>
    
        constructor(jsonString: String) {
            MyResultSetUtil.jsonToResultData(jsonString) { f, l ->
                fields = f
                result = l
            }
        }
    
        constructor(fields: List<String>, list: List<List<String>>) {
            this.fields = fields
            this.result = list
        }
    
        override fun findColumn(columnLabel: String) = fields.indexOf(columnLabel)
        override fun getStatement(): Statement? = null
        override fun getWarnings(): SQLWarning? = null
    
        override fun beforeFirst() {
            checkIfClosed()
            position = -1
        }
    
        override fun close() {
            isClosed = true
        }
    
        override fun isFirst(): Boolean {
            checkIfClosed()
            return position == 0
        }
    
        override fun isLast(): Boolean {
            checkIfClosed()
            return position == result.size - 1
        }
    
        override fun last(): Boolean {
            position = result.size - 1
            return result.isNotEmpty()
        }
    
        override fun isAfterLast(): Boolean {
            checkIfClosed()
            return position >= result.size
        }
    
        override fun relative(rows: Int): Boolean {
            checkIfClosed()
            return if (rows + position in 1 until result.size) {
                position += rows
                true
            } else {
                false
            }
        }
    
        override fun absolute(row: Int): Boolean {
            checkIfClosed()
            return if (row in 1 until result.size) {
                position = row
                true
            } else {
                false
            }
        }
    
        override fun next(): Boolean {
            checkIfClosed()
            return if (position < result.size - 1) {
                position++
                true
            } else {
                false
            }
        }
    
        override fun first(): Boolean {
            checkIfClosed()
            position = 0
            return result.isNotEmpty()
        }
    
        override fun afterLast() {
            checkIfClosed()
            position = result.size
        }
    
        override fun previous(): Boolean {
            checkIfClosed()
            return if (position > 1) {
                position--
                true
            } else {
                false
            }
        }
    
        override fun isBeforeFirst(): Boolean {
            checkIfClosed()
            return position < 0
        }
    
        override fun getFloat(columnIndex: Int) = result[position][columnIndex].toFloat()
        override fun getFloat(columnLabel: String) = result[position][findColumn(columnLabel)].toFloat()
        override fun wasNull() = false
        override fun getRow() = position + 1
        override fun getType() = ResultSet.TYPE_SCROLL_INSENSITIVE
        override fun getString(columnIndex: Int) = result[position][columnIndex]
        override fun getString(columnLabel: String) = result[position][findColumn(columnLabel)]
        override fun getLong(columnIndex: Int) = result[position][columnIndex].toLong()
        override fun getLong(columnLabel: String) = result[position][findColumn(columnLabel)].toLong()
        override fun getTimestamp(columnIndex: Int): Timestamp = Timestamp.valueOf(result[position][columnIndex])
        override fun getTimestamp(columnLabel: String): Timestamp = Timestamp.valueOf(result[position][findColumn(columnLabel)])
        override fun getDouble(columnIndex: Int) = result[position][columnIndex].toDouble()
        override fun getDouble(columnLabel: String) = result[position][findColumn(columnLabel)].toDouble()
        override fun getInt(columnIndex: Int) = result[position][columnIndex].toInt()
        override fun getInt(columnLabel: String) = result[position][findColumn(columnLabel)].toInt()
    
        override fun isClosed() = isClosed
        override fun getFetchSize() = result.size
        override fun getConcurrency() = ResultSet.CONCUR_READ_ONLY
    
        override fun clearWarnings() {
            checkIfClosed()
        }
    
        override fun getFetchDirection() = ResultSet.TYPE_SCROLL_INSENSITIVE
    
        private fun checkIfClosed() {
            if (isClosed()) throw SQLException()
        }
        
        ... ...
    }
    

    这个看起来就有点复杂了,主要是对游标的移动和获取值的操作,同样的,这里也有一个自定义的函数 MyResultSetUtil.jsonToResultData,用于将 json 字符串转换为二维数组。这也就意味着我们在这里已经决定了数据的传递方式,以是 json 作为媒介的。

    转换函数的实现如下:

    fun jsonToResultData(jsonString: String, callback:(fields: List<String>, data: List<List<String>>) -> Unit) {
        val fields = getFields(jsonString)
        val data = mutableListOf<List<String>>()
        JSONArray(jsonString).forEach { _, obj -> data.add(fields.map { obj.get(it).toString() }) }
        callback(fields, data)
    }
    
      private fun getFields(jsonString: String) = try {
        JSONArray(jsonString).run { if (length() > 0) getJSONObject(0).keySet().toList() else listOf() }
    } catch (th: Throwable) {
        throw SQLException("Cannot get result fields.")
    }
    

    最后是补全驱动所需的另两个文件,分别是 DatabaseMetaDataResultSetMetaData

    其实这两个 MetaData 都可以什么都不填,因为基本上用不到,只是 JDBC 标准说必须实现,所以才予以实现,通常的处理方法是将其中所有的方法都标记为“不支持”:

    throw SQLFeatureNotSupportedException()
    

    像这样就可以了。


    好了,是不是现在就想跑起代码来看看效果?我们还有最后一步,还记得上面提到的 IO 对象不,现在来实现这个对象,以完成对数据的请求。当然了,在这里我们使用的是写死的假数据:

    object MyTestRequset {
        var LOCAL_TEST = false
    
        private val SAMPLEDATA = """[{"id":1, "name":"test1", "age":10},{"id":2, "name":"test2", "age":20},{"id":3, "name":"test3", "age":30},{"id":4, "name":"test4", "age":40},{"id":5, "name":"test5", "age":50}]"""
    
        @TestOnly
        fun localTestInternalRequest(sql: String) = if (sql.contains("select ")) SAMPLEDATA else "1"
    }
    
    class MyIO(private val prop: Map<String, String>) {
        fun internalExecuteQuery(sql: String) = try {
            MyResultSet(internalRequest(sql))
        } catch (th: Throwable) {
            println("internalExecuteQuery error: $th")
            null
        } ?: throw SQLException("cannot parse ResultSet")
    
        fun internalExecuteUpdate(sql: String) = try {
            internalRequest(sql).toInt()
        } catch (th: Throwable) {
            println("internalExecuteUpdate error: $th")
            -1
        }
    
        private fun internalRequest(sql: String): String {
            if (MyTestRequset.LOCAL_TEST) return MyTestRequset.localTestInternalRequest(sql)
            TODO("获取数据的真实代码写在此处")
        }
    }
    
    

    好了,现在我们的代码已经完整了,可以运行看看效果,在此写一个 Testcase 来跑一下:

    class Test {
        @Test
        fun doTest() {
            MyTestRequset.LOCAL_TEST = true
            Class.forName("com.sample.MyDriver")
            DriverManager.getConnection("jdbc:myurl://0.0.0.0/sampledb", Properties().apply { setProperty(PROP_SCHEMA, "http") }).use { conn ->
                conn.prepareStatement("select * from Data").use { stmt ->
                    stmt.executeQuery().use { result ->
                        while (result.next()) {
                            println(result.getString("name"))
                        }
                    }
                }
                conn.prepareStatement("insert into Data(name) values (?)").use { stmt ->
                    stmt.setString(1, "23333")
                    println(stmt.executeUpdate())
                }
            }
        }
    }
    

    能顺利跑通就说明我们的驱动已经正常工作了。同样的,符合 JDBC 标准的驱动也可以被 myBatis 等框架加载并使用。


    好了,下面是大招,还记得上面的 MyIO 里有一个 TODO 吗?我们完全可以把对数据库的请求代理掉,让它成为一个远程的数据请求,代码如下:

    private fun internalRequest(sql: String): String {
        var ret: String? = null
        http {
            url = "${if (prop.containsKey(PROP_SCHEMA)) prop[PROP_SCHEMA] else "http"}://${prop[PROP_HOST]}${prop[PROP_PORT]}/${prop[PROP_PATH]}"
            method = HttpMethod.POST
            if (prop.containsKey(PROP_USER)) authenticatorUser = prop[PROP_USER]
            if (prop.containsKey(PROP_PASSWORD)) authenticatorPassword = prop[PROP_PASSWORD]
            postParam = mutableMapOf("sql" to sql)
            onSuccess { code, text, _ ->
                if (code != 200) throw SQLException("Remote execute SQL failed: $code")
                ret = text
            }
        }
        return ret ?: throw SQLException("Remote SQL result is null.")
    }
    

    同时,只需要使用 Ktor 写几行代码,跑起服务器,这一切都顺理成章了(还不会 Ktor 的小伙伴可以看我的 Ktor 从入门到放弃 系列)。

    服务端代码:

    fun Routing.ISCRouting() {
        post("/sampledb") {
            val sql = call.requestParameters()["sql"] ?: ""
            call.respondText { doRequestDb(sql) }
        }
    }
    

    doRequestDb 的过程中,就可以做各种骚操作了,如分库分表,权限控制等,在此就不赘述了,大家可以发挥自己的想象力。


    最后,最上面提到的那个生成 JDBC 驱动代码的工具,可以从我的 Github 下载 EasyJDBC 并编译,然后愉快的开发吧。

    相关文章

      网友评论

        本文标题:实现自己的 JDBC 驱动

        本文链接:https://www.haomeiwen.com/subject/ehdhtktx.html