说到 Druid,你第一个想到的是啥?
是这个?
还是这个?
又或者是这个?
反正今天我不说他们,今天要说的是阿里开源的 Druid 项目,利用其 SQL Parser 来进行数据库的迁移是一件体验很不错的事。
那现就开始吧,这次来实现 SQL 语句的翻译,来尝试将 Oracle 的建表语句翻译为 MySQL 的。
赶紧新建一个工程,具体怎么建的不多说了,然后加入对 Druid 的引用:
compile 'com.alibaba:druid:1.1.16'
在 Druid 内有标准的 AST(点击查看什么是 AST) 和各种 SQL 方言的解析支持,其中 Parser
和 各种 Statement
即是其关键所在。
那就开始写一些代码吧,我们有以下 Oracle 建表语句:
create table T_ITEM (
process_session_id number CONSTRAINT pspk PRIMARY KEY,
org_id NUMBER(19) DEFAULT 0,
comp_id NUMBER(19) NOT NULL,
item_id NUMBER(19) CONSTRAINT psuk UNIQUE,
qty NUMBER(19,2),
item_desc VARCHAR2(255),
tisp TIMESTAMP(6) DEFAULT SYSDATE,
create_by NUMBER(19),
create_time TIMESTAMP(6),
modi_by NUMBER(19),
modi_time TIMESTAMP(6)
);
要将 SQL 语句解成 Statement
,需要用到 Parser
,如以上 Oracle 语句,则可以用以下代码来解:
val parser = OracleStatementParser(SQLSTR)
val stmt = parser.parseStatementList()[0] as OracleCreateTableStatement
需要注意的是,Druid 可以对整个 SQL 脚本文件进行解析,并根据内容生成多个 Statement,所以 parse 的结果是一个数组,这是正确的情况。而此处由于只有一个 create table,并不需要对其他内容作解析,因此直接拿了第 0 个元素,并且我们已经知道了该 Statement 的类型为 OracleCreateTableStatement
。
对于 Statement,通用的命名规则是 <数据库类型><语句类型>Statement
,记住这个规则将很容易可以找到对应的 Statement 类。
然后就需要从这个 Statement 里拿到所需的信息,如表名和字段信息:
// 获取表名
val tableName = stmt.tableSource.name.simpleName
// 获取字段信息
val tableFields = stmt.tableElementList
这里的 tableFields
还需要进一步的解析,以获取字段的数据类型,以及各种约束,默认值等。
tableFields.forEach {
it as SQLColumnDefinition
// 获取字段名称
it.name
// 获取字段数据类型
it.dataType
// 获取字段的约束
it.constraints
// 获取字段的默认值
it.defaultExpr
}
当然了,这里获取到的数据类型是 Oracle 的,它无法被 MySQL 识别,因此就需要对数据类型进行转换,好在 Druid 已经为我们提供了这样的函数:
val mysqlDataType = SQLTransformUtils.transformOracleToMySql(it.dataType)
这样就得到了一个作用于 MySQL 的数据类型。
最后,对于默认值,有一个地方需要注意的,就是 DEFAULT SYSDATE
,在翻译的过程中,Druid 并不认识 Oracle 的各种函数,所以我们需要手动进行转换:
fun convertDefaultValue(type: String, def: String) =
if (type.startsWith("datetime", ignoreCase = true) && def.equals("SYSDATE", ignoreCase = true)) {
"NOW()"
} else if (type.startsWith("timestamp", ignoreCase = true)) {
"CURRENT_TIMESTAMP()"
} else {
def
}
好了,那么准备工作都做好了,现在就可以写出转换的函数了,直接给完整代码:
fun convertCreateTable(o: OracleCreateTableStatement) =
SQLUtils.formatMySql("create table ${o.tableSource.name.simpleName} (${o.tableElementList.toMySQLFieldDeclare()});")
private fun List<SQLTableElement>.toMySQLFieldDeclare() = joinToString(",") {
it as SQLColumnDefinition
val ct = SQLTransformUtils.transformOracleToMySql(it.dataType).toString()
var str = "${it.name} $ct"
if (it.constraints.isNotEmpty()) {
it.constraints.forEach { c ->
when (c) {
is SQLColumnPrimaryKey -> str += " primary key"
is SQLColumnUniqueKey -> str += " unique"
is SQLNotNullConstraint -> str += " not null"
}
}
}
if (it.defaultExpr != null) str += " default ${convertDefaultValue(ct, it.defaultExpr.toString())}"
str
}
现在就可以运行程序了,得到最终结果:
fun main(args: Array<String>) {
val parser = OracleStatementParser(SQLSTR)
val stmt = parser.parseStatementList()[0] as OracleCreateTableStatement
println(convertCreateTable(stmt))
}
得到的结果如下:
CREATE TABLE T_ITEM (
process_session_id decimal(38) PRIMARY KEY,
org_id bigint DEFAULT 0,
comp_id bigint NOT NULL,
item_id bigint UNIQUE,
qty decimal(19, 2),
item_desc varchar(255),
tisp datetime(6) DEFAULT NOW(),
create_by bigint,
create_time datetime(6),
modi_by bigint,
modi_time datetime(6)
);
把这个语句放进 MySQL 里执行,报告成功,即告转换完成。
网友评论