美文网首页
自己实现一个简单Golang ORM函数库

自己实现一个简单Golang ORM函数库

作者: 小艾咪 | 来源:发表于2021-05-01 19:42 被阅读0次

    前言

    通过该项目,对go的反射有了更深入的了解。特意记录下。将要使用的sql驱动为github.com/go-sql-driver/mysql

    正文

    数据库初始化

    任何sql操作都离不开初始化,调用sql.Open(dbType,dataSource);即可初始化数据库。但需要注意的是该函数是golang官方的数据库规范接口其具体实现交由第三处理。所以需要在需要初始化的包中导入并初始化第三方包的想关函数。

    package db
    
    import (
        "database/sql"
        "fmt"
        "time"
    
        //init sql
        _ "github.com/go-sql-driver/mysql"
    )
    
    //DB db oprate instance
    var DB *sql.DB
    
    var sqlType = "mysql"
    var dataSource = "root:123456@tcp(localhost)/alming"
    
    func init() {
        DB, _ = sql.Open(sqlType, dataSource)
        DB.SetConnMaxLifetime(time.Minute * 3)
        DB.SetMaxOpenConns(10)
        DB.SetMaxIdleConns(10)
        if err := DB.Ping(); err == nil {
            fmt.Println("Connect success:")
        } else {
            fmt.Println("Connect fail:", err)
        }
    }
    
    

    单结果查询操作

    需要解决的问题是如何将Sql操作的结果集映射到struct中。首先看下常规查询操作

    //以下代码基本为伪代码,未经测试仅展示流程
    type User struct{
        Username string
        Password string
    }
    rows,err:=DB.Query("select * from user")
    user:=new(User)
    rows.Next()
    rows.Scan(&user.Username,&user.Password)
    

    可以看到,Scan()方法接收的是指针类型参数,所以说要创建一个指针容器用于存放结果集。那么有两个问题:容器内指针是什么类型,容器的大小又是多少。这时我们需要使用rows实例的另一个函数ColumnTypes()它返回一个[]*sql.ColumnType其数组内元素包含每列结果的数据库类型。有了数据库类型就可以根据数据库类型创建go类型参数。而该返回值的个数也就是我们需要创建的容器大小。详情见代码

    //aldb.go
    //获得结果集所有列信息以创建接收结果的容器
    rc, err := rows.ColumnTypes()
    if err != nil {
        log.Println("Get column types fail")
    }
    container := createContainer(rc)
    if !rows.Next() {
        return false
    }
    column, _ := rows.Columns()
    rows.Scan(container...)
    success := mapResult(container, column, rs)
    
    //dbutil.go
    func createContainer(columnTyes []*sql.ColumnType) (params []interface{}) {
        params = make([]interface{}, len(columnTyes))
        for i, ct := range columnTyes {
            params[i] = createSlot(ct.DatabaseTypeName())
        }
        return
    }
    //这里也是仅列出了常用的类型,如需扩展再进行类型添加
    func createSlot(dbType string) interface{} {
        switch dbType {
        case "INT", "TINYINT", "BIGINT":
            return new(int)
        case "MEDIUMINT":
        case "DOUBLE":
            return new(float32)
        case "DECIMAL":
        case "CHAR":
            return new(byte)
        case "VARCHAR", "TEXT", "LONGTEXT":
            return &sql.NullString{String: "", Valid: true}
        case "BIT":
            return new(interface{})
        case "DATE":
            return &sql.NullString{String: "", Valid: false}
        case "DATETIME":
            return &sql.NullString{String: "", Valid: false}
        case "TIMESTAMP":
            return &sql.NullString{String: "", Valid: false}
        }
        return nil
    }
    

    这里有一个坑就是想要映射为golang的string类型时需要使用sql.NullString,否则当驱动扫描到一个值为NULL的列时将不会继续扫描后面的结果将会获取不到

    另外单结果查询我们还需要判断结果集是否为多个,因为有些业务只允许返回一个结果集,返回多个视为错误。实现起来也非常简单

    if rows.Next() {
        panic("QueryOne except one result but get no more one")
    }
    

    多结果查询操作

    多结果查询与单结果类似,只是在单结果上多了一个for循环

    rc, err := rows.ColumnTypes()
    if err != nil {
        log.Println("Get column types fail")
    }
    
    column, _ := rows.Columns()
    var oneMoreSet bool = false
    for rows.Next() {
        container := createContainer(rc)
        err = rows.Scan(container...)
        if err != nil {
            panic("Scan rows error")
        }
        oneMoreSet = mapResult(container, column, rs)
    }
    

    结果集映射

    可以看到前文中mapResult(container, column, rs)即为结果集映射函数,多结果与单结果共用一个函数,内部通过if判断区分以写操作。在接下来的源码中您可能会看到toPascalCase(columns[i])函数,该函数是一个工具函数它将sql列命映射成为Golang命名规范的变量命方便使用反射。映射规则是将首字母大写,_后第一个字母大写,其源码为

    func toPascalCase(src string) string {
        var dst = make([]uint8, 0)
        if src[0] > 96 && src[0] < 123 {
            dst = append(dst, src[0]-32)
        } else {
            dst = append(dst, src[0])
        }
        for i := 1; i < len(src); {
            if src[i] == '_' {
                if src[0] > 96 && src[0] < 123 {
                    dst = append(dst, src[i+1]-32)
                }
                i += 2
            } else {
                dst = append(dst, src[i])
                i++
    
            }
        }
        return string(dst)
    }
    

    然后继续看映射部分

    //mapResult 将sql rows扫描到的数据填入给定的结构中(结构体或slice)
    //container :单条结果容器,columns 结果集对应数据库中的列名,value
    //被映射对象
    func mapResult(container []interface{}, columns []string, value reflect.Value) bool {
        var slot reflect.Value
        var arr = make([]reflect.Value, 0)
        //判断待映射类型,结构以与slice分别处理
        if value.Elem().Kind() == reflect.Struct {
            slot = value.Elem()
        } else {
            //slice内数据类型的实例
            slot = reflect.New(value.Type().Elem().Elem()).Elem()
        }
        var oneMoreSet = false
        //遍历一行结果集找到其在结构体中的位置并赋值
        for i, v := range container {
            //找到对应结构体的属性
            slotField := slot.FieldByName(toPascalCase(columns[i]))
            if slotField.CanSet() {
                switch value := v.(type) {
                case *int:
                    //只有与其结构体类型匹配才赋值
                    if slotField.Kind() == reflect.Int {
                        slotField.SetInt(int64(*value))
                    }
                    oneMoreSet = true
                case *string:
                    if slotField.Kind() == reflect.String {
                        slotField.SetString(*value)
                    }
                    oneMoreSet = true
                case *sql.NullString:
                    if slotField.Kind() == reflect.String {
                        slotField.SetString(value.String)
                    }
                    oneMoreSet = true
                }
            }
        }
        //如果被映射对象是slice也就是多结果集映射要通过反射将映射出的
        //结构体实例追加到结果集中
        if value.Elem().Kind() == reflect.Slice {
            arr = append(arr, slot)
            added := reflect.Append(value.Elem(), arr...)
            value.Elem().Set(added)
        }
        return oneMoreSet
    }
    

    插入更新操作

    这部分我实现了一个自定义SQL格式,使用时需按该格式编写sql。规定:sql中参数都使用:数据库列名代替,它看起来是下面这样

    update user set username=:username where id=:id
    

    使用时会像下面这样

    u := user{
        Id:       2,
        Username: "alming_update",
    }
    Exec(&u, "update user set username=:username where id=:id")
    

    其内部实现原理也非常简单,直接看源码

    //Exec excute sql with the params in the struct you give
    func Exec(structure interface{}, sqlStr string) (success bool) {
        rs := reflect.ValueOf(structure)
        pointTo := rs.Elem()
        //自定义sql 表达式中 ?由[]:变量名]代替,找到这些变量名并由反射根据改名称获取所给
        //结构体实例当中的数据作为参数传递给Exec函数
        reg, _ := regexp.Compile(`:[a-zA-z_]+`)
        regFind := reg.FindAllString(sqlStr, -1)
        //通过反射创建参数列表的容器
        params := make([]interface{}, len(regFind))
        //通过自定义sql表达式获取sql
        SQLParsed := reg.ReplaceAllString(sqlStr, "?")
        //通过自定义sql中:找到对应的参数
        for i, sqlArgs := range regFind {
            parseArg := strings.TrimPrefix(sqlArgs, `:`)
            fieldName := toPascalCase(parseArg)
            field := pointTo.FieldByName(fieldName)
            switch field.Kind() {
            case reflect.Int:
                //将参数添加到参数容器中
                params[i] = field.Int()
            case reflect.String:
                params[i] = field.String()
            case reflect.Float32, reflect.Float64:
                params[i] = field.Float()
            }
        }
        var res sql.Result
        var err error
        if len(params) > 0 {
            res, err = DB.Exec(SQLParsed, params...)
        } else {
            res, err = DB.Exec(SQLParsed)
        }
        if err == nil {
            rowAf, _ := res.RowsAffected()
            return rowAf > 0
        }
        return false
    }
    

    关于一对多问题

    该操作实现的过于笨重且限制较多,就不班门弄斧了。感兴趣可以看下源码。

    func QueryOneToMany(slice interface{}, sqlStr string, outPk string, inPk string, params ...interface{}) (resMatched bool) {
        defer catchPanic()
        rs := reflect.ValueOf(slice)
        pointTo := rs.Elem()
        if pointTo.Kind() != reflect.Slice {
            panic("QueryOne must to map to a slice,please check your structure parameter")
        }
        var rows *sql.Rows
        var err error
        if len(params) == 0 {
            rows, err = DB.Query(sqlStr)
        } else {
            rows, err = DB.Query(sqlStr, params...)
        }
        if err != nil {
            log.Println("An error occerred when exec query sql", err)
        }
    
        rc, err := rows.ColumnTypes()
        if err != nil {
            log.Println("Get column types fail")
        }
    
        column, _ := rows.Columns()
        var allRows = make([][]interface{}, 0)
        for rows.Next() {
            container := createContainer(rc)
            err = rows.Scan(container...)
            if err != nil {
                panic("Scan rows error")
            }
            allRows = append(allRows, container)
        }
        //outPk,对应“一”的主键,inPk对应“多”的主键
        mapRes(allRows, column, rs, 0, outPk, inPk)
        //别忘改
        return true
    }
    
    //mapRes 将查询的结果集按一对多形式映射到结构当中
    //allRows 所有结果集,columns 结果集对应数据库中的列名,value
    //被映射对象,height工具属性与可变参数pk配合使用,pk(primary
    //key)设计目的是为了兼容QueryOne与Query的结果集映射。实际
    //这两个方法有单独的映射函数
    func mapRes(allRows [][]interface{}, columns []string, value reflect.Value, height int, pk ...string) {
        in := value.Elem()
        inType := in.Type().Elem()
        var inSlot reflect.Value
        var inSlotName string
        //查找给定结构的slice属性并为其
        for i := 0; i < inType.NumField(); i++ {
            if inType.Field(i).Type.Kind() == reflect.Slice {
                //记录改属性属性名方便之后通过反射获取改属性并为其赋值
                inSlotName = inType.Field(i).Name
                inSlot = reflect.New(inType.Field(i).Type)
                mapRes(allRows, columns, inSlot, height+1, pk...)
            }
        }
        //mark为一个标识,以sql primary key为map,通过它标识同一元素是否被重复扫描
        mark := make(map[interface{}]byte)
        //主键在column中索引位置,方便获取主键值并配合mark判断是否重复扫描
        var pkIdx = -1
        if len(pk) > 0 {
            pkIdx = getColIndex(columns, pk[height])
        }
        var arr = make([]reflect.Value, 0)
        for _, row := range allRows {
            if mark[pkValue(row[pkIdx])] == 1 {
                continue
            }
            outSlot := reflect.New(inType).Elem()
            var oneMoreSet = false
            for i, v := range row {
                slot := outSlot.FieldByName(toPascalCase(columns[i]))
                if slot.CanSet() {
                    switch setValue := v.(type) {
                    case *int:
                        if slot.Kind() == reflect.Int {
                            slot.SetInt(int64(*setValue))
                            oneMoreSet = true
                        }
                    case *string:
                        if slot.Kind() == reflect.String {
                            slot.SetString(*setValue)
                            oneMoreSet = true
                        }
                    case *sql.NullString:
                        if slot.Kind() == reflect.String {
                            slot.SetString(setValue.String)
                            oneMoreSet = true
                        }
                    }
                }
            }
            slot := outSlot.FieldByName(inSlotName)
            if slot.CanSet() {
                slot.Set(inSlot.Elem())
            }
            if oneMoreSet {
                if len(pk) > 0 {
                    mark[pkValue(row[pkIdx])] = 1
                }
            }
            arr = append(arr, outSlot)
        }
        added := reflect.Append(in, arr...)
        in.Set(added)
    }
    
    func getColIndex(colunms []string, col string) int {
        for idx, item := range colunms {
            if item == col {
                return idx
            }
        }
        return -1
    }
    func pkValue(pkContent interface{}) interface{} {
        switch v := pkContent.(type) {
        case *int:
            return *v
        case *byte:
            return *v
        case *float32:
            return *v
        case *string:
            return *v
        case *sql.NullString:
            return v.String
        default:
            return nil
        }
    }
    

    总结

    关于Go反射

    1. go反射不像java,go必须在已有实例上进行反射。

    2. go使用反射修改实例内容时需要反射的内容必须为指针类型(可通过CanSet()判断该属性是否可以赋值),并且修改时需要调用Elem()方法获取其指向的元素。

    3. 反射slice添加元素比较复杂详情见代码。

    4. Elem()返回指针所指向的元素,如果是数组类型则返回其内部元素的类型。

    5. 可以通过reflect.New()创建新的实例,但与第一条不冲突(创建实例所需的类型参数由反射已有实例获得)

    附录

    源代码:alming_backend

    一些平台禁止外链 https://github.com/ALMing530/alming_backend

    进入该项目db文件夹下查看

    相关文章

      网友评论

          本文标题:自己实现一个简单Golang ORM函数库

          本文链接:https://www.haomeiwen.com/subject/mvxxdltx.html