03 GORM Source Code Walkthrough

Author: 刷漆猫咪 | Published: 2020-01-14 18:02

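    Introduction

    A walkthrough of the GORM source code, based on v1.9.11.

    Model interaction

    An earlier part covered how models are defined and parsed. This part looks at how a model interacts with the database.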

    package main
    
    import (
      "github.com/jinzhu/gorm"
      _ "github.com/jinzhu/gorm/dialects/sqlite"
    )
    
    type Product struct {
      gorm.Model
      Code string
      Price uint
    }
    
    func main() {
      db, err := gorm.Open("sqlite3", "test.db")
      if err != nil {
        panic("failed to connect database")
      }
      defer db.Close()
    
      // Migrate the schema
      db.AutoMigrate(&Product{})
    
      // Create
      db.Create(&Product{Code: "L1212", Price: 1000})
    
      // Read
      var product Product
      db.First(&product, 1) // find the product with id 1
      db.First(&product, "code = ?", "L1212") // find the product with code L1212
    
      // Update - set the product's price to 2000
      db.Model(&product).Update("Price", 2000)
    
      // Delete - delete the product
      db.Delete(&product)
    }
    

    AutoMigrate

    Once the model is defined, the first step is to migrate it with AutoMigrate:

    db.AutoMigrate(&Product{})
    

    Here is its source:

    // AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
    func (s *DB) AutoMigrate(values ...interface{}) *DB {
        db := s.Unscoped()
        for _, value := range values {
            db = db.NewScope(value).autoMigrate().db
        }
        return db
    }
    

    Internally it calls db.NewScope(value).autoMigrate() for each argument passed in.

    So how exactly is the migration carried out?

    func (scope *Scope) autoMigrate() *Scope {
        tableName := scope.TableName()
        quotedTableName := scope.QuotedTableName()
    
        if !scope.Dialect().HasTable(tableName) {
            scope.createTable()
        } else {
            for _, field := range scope.GetModelStruct().StructFields {
                if !scope.Dialect().HasColumn(tableName, field.DBName) {
                    if field.IsNormal {
                        sqlTag := scope.Dialect().DataTypeOf(field)
                        scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
                    }
                }
                scope.createJoinTable(field)
            }
            scope.autoIndex()
        }
        return scope
    }
    

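    The if/else in the middle shows two paths. If the table does not exist yet, it is simply created.

    Otherwise every field of the model is checked, and if the column for a normal field does not exist, the table is altered to add it.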

    scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
    

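    How the SQL statement gets executed can wait. The form of the code is pleasantly concise, though: Raw builds the statement and Exec runs it.

    In addition, the join table is updated for every field of the model with scope.createJoinTable(field).

    After the for loop has processed every field, the indexes are updated with scope.autoIndex().

    To sum up, auto migration does the following: create the table, add missing columns, update join tables, and update indexes.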
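
    The decision flow can be sketched without a database. This is a simplified stand-in rather than GORM's code: planMigration, column, and the in-memory existing set are hypothetical names, quoting is hard-coded to double quotes, and join tables and indexes are left out.

```go
package main

import (
	"fmt"
	"strings"
)

// column is a hypothetical stand-in for a parsed model field.
type column struct {
	Name    string
	SQLType string
}

// planMigration returns the statements an auto-migration would need:
// CREATE TABLE when the table is missing, otherwise one ALTER TABLE
// per column that is not present yet. Existing columns are never touched.
func planMigration(table string, tableExists bool, existing map[string]bool, cols []column) []string {
	if !tableExists {
		defs := make([]string, 0, len(cols))
		for _, c := range cols {
			defs = append(defs, fmt.Sprintf("%q %s", c.Name, c.SQLType))
		}
		return []string{fmt.Sprintf("CREATE TABLE %q (%s)", table, strings.Join(defs, ","))}
	}
	var stmts []string
	for _, c := range cols {
		if !existing[c.Name] {
			stmts = append(stmts, fmt.Sprintf("ALTER TABLE %q ADD %q %s;", table, c.Name, c.SQLType))
		}
	}
	return stmts
}

func main() {
	cols := []column{{"id", "integer"}, {"code", "varchar(255)"}, {"price", "integer"}}
	existing := map[string]bool{"id": true, "code": true}
	for _, stmt := range planMigration("products", true, existing, cols) {
		fmt.Println(stmt) // ALTER TABLE "products" ADD "price" integer;
	}
}
```

    Columns that already exist are skipped entirely, which matches the "will only add missing fields" promise in the AutoMigrate comment.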

    createTable

    The table creation step was skipped above, so let's look closely at how a table is created.

    func (scope *Scope) createTable() *Scope {
        var tags []string
        var primaryKeys []string
        var primaryKeyInColumnType = false
        for _, field := range scope.GetModelStruct().StructFields {
            if field.IsNormal {
                sqlTag := scope.Dialect().DataTypeOf(field)
    
                // Check if the primary key constraint was specified as
                // part of the column type. If so, we can only support
                // one column as the primary key.
                if strings.Contains(strings.ToLower(sqlTag), "primary key") {
                    primaryKeyInColumnType = true
                }
    
                tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
            }
    
            if field.IsPrimaryKey {
                primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
            }
            scope.createJoinTable(field)
        }
    
        var primaryKeyStr string
        if len(primaryKeys) > 0 && !primaryKeyInColumnType {
            primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
        }
    
        scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
    
        scope.autoIndex()
        return scope
    }
    

    This is how the CREATE TABLE statement is built. The key line is:

    scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
    

    Everything before it iterates over the model's fields, gets the sqlTag of each one, and appends it to tags:

    tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
    

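    Each entry is the quoted column name, a space, and the sqlTag. The quoting comes from the dialect; with sqlite3 it is double quotes.

    This step also deals with the primary key, and that part feels like a trap, because the result of
    sqlTag := scope.Dialect().DataTypeOf(field) depends on how each database dialect implements DataTypeOf.

    Issue 2270 reports a table ending up with more than one primary key.
    It used the following model definition, with sqlite3 as the database: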

    type Permission struct {
        ID   int64  `gorm:"AUTO_INCREMENT;column:id;primary_key"`
        Name string `gorm:"column:name;type:varchar;unique;not null"`
        Idx  int64  `gorm:"AUTO_INCREMENT"`
    }
    

    Although the model declares only one primary_key, Idx ended up as a primary key as well:

    [2019-01-19 19:40:30]  table "permission" has more than one primary key
    
    [2019-01-19 19:40:30]  [0.14ms]  CREATE TABLE "permission" ("id" integer primary key autoincrement,"name" varchar NOT NULL UNIQUE,"idx" integer primary key autoincrement )
    [0 rows affected or returned ]
    

    There is a single cause: the field uses the AUTO_INCREMENT option, and the sqlite3 implementation of DataTypeOf contains:

    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
      if s.fieldCanAutoIncrement(field) {
        field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
        sqlType = "integer primary key autoincrement"
      } else {
        sqlType = "integer"
      }
    case reflect.Int64, reflect.Uint64:
      if s.fieldCanAutoIncrement(field) {
        field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
        sqlType = "integer primary key autoincrement"
      } else {
        sqlType = "bigint"
      }
    

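    The AUTO_INCREMENT option causes the returned type to contain primary key.

    I suspect this is a bug, because later code already checks whether a field is a primary key and builds primaryKeyStr: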

    if field.IsPrimaryKey {
      primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
    }
    
    var primaryKeyStr string
    if len(primaryKeys) > 0 && !primaryKeyInColumnType {
      primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
    }
    

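    In my view sqlType should not carry any primary key information.
    The primary key can be set afterwards through primaryKeyStr. (SQLite itself only accepts AUTOINCREMENT on an INTEGER PRIMARY KEY column, which is presumably why the dialect emits it inline.)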
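
    The interaction can be reproduced in a few lines. The sketch below is not GORM's code: field, sqliteType, and buildCreateTable are hypothetical names, and sqliteType only imitates the auto-increment branch of the sqlite3 DataTypeOf shown earlier.

```go
package main

import (
	"fmt"
	"strings"
)

// field is a hypothetical, reduced version of a model field.
type field struct {
	Name          string
	AutoIncrement bool
	PrimaryKey    bool
}

// sqliteType imitates the behaviour discussed above: an auto-increment
// integer column gets "primary key" baked into its column type.
func sqliteType(f field) string {
	if f.AutoIncrement {
		return "integer primary key autoincrement"
	}
	return "integer"
}

// buildCreateTable follows the same steps as createTable: collect column
// definitions, remember the declared primary keys, and only add a trailing
// PRIMARY KEY clause when no column type already declares one.
func buildCreateTable(table string, fields []field) string {
	var defs, primaryKeys []string
	primaryKeyInColumnType := false
	for _, f := range fields {
		sqlType := sqliteType(f)
		if strings.Contains(strings.ToLower(sqlType), "primary key") {
			primaryKeyInColumnType = true
		}
		defs = append(defs, fmt.Sprintf("%q %s", f.Name, sqlType))
		if f.PrimaryKey {
			primaryKeys = append(primaryKeys, fmt.Sprintf("%q", f.Name))
		}
	}
	primaryKeyStr := ""
	if len(primaryKeys) > 0 && !primaryKeyInColumnType {
		primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ","))
	}
	return fmt.Sprintf("CREATE TABLE %q (%s%s)", table, strings.Join(defs, ","), primaryKeyStr)
}

func main() {
	// Same shape as the Permission model: one declared primary key,
	// but two auto-increment columns.
	fmt.Println(buildCreateTable("permission", []field{
		{Name: "id", AutoIncrement: true, PrimaryKey: true},
		{Name: "idx", AutoIncrement: true},
	}))
}
```

    The statement printed by main contains "primary key" twice, once per auto-increment column, even though only id was declared as the primary key.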

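    That is enough about primary keys.

    Both the migration path and the creation path call createJoinTable. Relationships have not been studied in depth yet, so it is skipped for now.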

    callbacks

    Create, read, update, and delete all go through the callbacks field of the DB struct:

    // DB contains information for current db connection
    type DB struct {
      ...
        // global db
        parent        *DB
        callbacks     *Callback
        dialect       Dialect
        singularTable bool
      ...
    }
    

    Here is the code of the Create method:

    // Create insert the value into database
    func (s *DB) Create(value interface{}) *DB {
        scope := s.NewScope(value)
        return scope.callCallbacks(s.parent.callbacks.creates).db
    }
    

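    It calls callCallbacks on a new scope, passing s.parent.callbacks.creates.
    parent is also a *DB. It points at the DB value created by Open, so every derived *DB reaches the same callbacks through it, a bit like inheritance.

    Digging further into callCallbacks: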

    func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
        defer func() {
            if err := recover(); err != nil {
                if db, ok := scope.db.db.(sqlTx); ok {
                    db.Rollback()
                }
                panic(err)
            }
        }()
        for _, f := range funcs {
            (*f)(scope)
            if scope.skipLeft {
                break
            }
        }
        return scope
    }
    

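    It uses the defer plus recover pattern: if a callback panics and the underlying connection is a transaction, the transaction is rolled back and the panic is raised again. The pattern was covered earlier, so no more detail here.

    The argument of callCallbacks is a slice of function pointers. They are called one after another, and the loop stops early once scope.skipLeft is true.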
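
    A minimal sketch of the same control flow, with no database involved. The names scope, runCallbacks, and trace are hypothetical, and the "rollback" entry merely stands in for the real db.Rollback() call.

```go
package main

import "fmt"

// scope is a hypothetical stand-in for gorm's Scope.
type scope struct {
	skipLeft bool
	trace    []string
}

// runCallbacks calls each function in order and stops early once a
// callback sets skipLeft. A panic inside a callback is intercepted so that
// cleanup (a rollback in GORM's case) can run before it is raised again.
func runCallbacks(s *scope, funcs []func(*scope)) *scope {
	defer func() {
		if err := recover(); err != nil {
			s.trace = append(s.trace, "rollback")
			panic(err)
		}
	}()
	for _, f := range funcs {
		f(s)
		if s.skipLeft {
			break
		}
	}
	return s
}

func main() {
	s := runCallbacks(&scope{}, []func(*scope){
		func(s *scope) { s.trace = append(s.trace, "begin") },
		func(s *scope) { s.trace = append(s.trace, "create"); s.skipLeft = true },
		func(s *scope) { s.trace = append(s.trace, "never reached") },
	})
	fmt.Println(s.trace) // [begin create]
}
```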

    Having seen how callbacks are invoked, let's see what Callback actually is.

    // Callback is a struct that contains all CRUD callbacks
    //   Field `creates` contains callbacks will be call when creating object
    //   Field `updates` contains callbacks will be call when updating object
    //   Field `deletes` contains callbacks will be call when deleting object
    //   Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
    //   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
    //   Field `processors` contains all callback processors, will be used to generate above callbacks in order
    type Callback struct {
        logger     logger
        creates    []*func(scope *Scope)
        updates    []*func(scope *Scope)
        deletes    []*func(scope *Scope)
        queries    []*func(scope *Scope)
        rowQueries []*func(scope *Scope)
        processors []*CallbackProcessor
    }
    

    Callback holds several function slices, one per kind of operation. The comments explain them clearly.

    Note CallbackProcessor, which is used to generate all the callbacks in order.

    // CallbackProcessor contains callback informations
    type CallbackProcessor struct {
        logger    logger
        name      string              // current callback's name
        before    string              // register current callback before a callback
        after     string              // register current callback after a callback
        replace   bool                // replace callbacks with same name
        remove    bool                // delete callbacks with same name
        kind      string              // callback type: create, update, delete, query, row_query
        processor *func(scope *Scope) // callback handler
        parent    *Callback
    }
    
    // Create could be used to register callbacks for creating object
    //     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
    //       // business logic
    //       ...
    //
    //       // set error if some thing wrong happened, will rollback the creating
    //       scope.Err(errors.New("error"))
    //     })
    func (c *Callback) Create() *CallbackProcessor {
        return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
    }
    
    // Update could be used to register callbacks for updating object, refer `Create` for usage
    func (c *Callback) Update() *CallbackProcessor {
        return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
    }
    
    // Delete could be used to register callbacks for deleting object, refer `Create` for usage
    func (c *Callback) Delete() *CallbackProcessor {
        return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
    }
    
    // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
    // Refer `Create` for usage
    func (c *Callback) Query() *CallbackProcessor {
        return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
    }
    
    // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
    func (c *Callback) RowQuery() *CallbackProcessor {
        return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
    }
    

    Callback has one method per kind for creating the matching CallbackProcessor.

    // After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
    func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
        cp.after = callbackName
        return cp
    }
    
    // Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
    func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
        cp.before = callbackName
        return cp
    }
    

    After and Before set the corresponding field on the CallbackProcessor, which is used later to compute the callback order.

    db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
      // business logic
      ...
    
      // set error if some thing wrong happened, will rollback the creating
      scope.Err(errors.New("error"))
    })
    

    This is the example from the comment. Next, the Register method.

    // Register a new callback, refer `Callbacks.Create`
    func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
        if cp.kind == "row_query" {
            if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
                cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
                cp.before = "gorm:row_query"
            }
        }
    
        cp.name = callbackName
        cp.processor = &callback
        cp.parent.processors = append(cp.parent.processors, cp)
        cp.parent.reorder()
    }
    

    It mainly sets the name and processor fields of cp and appends cp to cp.parent.processors.
    It then calls cp.parent.reorder() to sort everything again.
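
    The chainable shape of this API is easy to reproduce. The sketch below uses hypothetical registry and entry types in place of Callback and CallbackProcessor, and it deliberately skips the reordering step.

```go
package main

import "fmt"

// registry and entry are hypothetical stand-ins for Callback and CallbackProcessor.
type registry struct{ entries []*entry }

type entry struct {
	name, before, after string
	fn                  func()
	parent              *registry
}

// Create starts a new, still unnamed entry bound to its registry.
func (r *registry) Create() *entry { return &entry{parent: r} }

// After and Before only record the constraint and return the entry,
// which is what makes the calls chainable.
func (e *entry) After(name string) *entry  { e.after = name; return e }
func (e *entry) Before(name string) *entry { e.before = name; return e }

// Register names the entry, stores the handler, and hands the entry to the
// registry. GORM re-sorts the whole list at this point; this sketch does not.
func (e *entry) Register(name string, fn func()) {
	e.name, e.fn = name, fn
	e.parent.entries = append(e.parent.entries, e)
}

func main() {
	r := &registry{}
	r.Create().Register("gorm:create", func() {})
	r.Create().After("gorm:create").Register("plugin:audit", func() {})
	for _, e := range r.entries {
		fmt.Printf("%s after=%q before=%q\n", e.name, e.after, e.before)
	}
}
```

    Because every registration appends a fresh entry instead of editing an existing one, the full history stays available to the sorting step.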

    Where there is a register method, there is of course a matching remove method:

    // Remove a registered callback
    //     db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
    func (cp *CallbackProcessor) Remove(callbackName string) {
        cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
        cp.name = callbackName
        cp.remove = true
        cp.parent.processors = append(cp.parent.processors, cp)
        cp.parent.reorder()
    }
    

    It sets the remove field to true and reorders.

    Replace works in a similar way:

    // Replace a registered callback with new callback
    //     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
    //         scope.SetColumn("Created", now)
    //         scope.SetColumn("Updated", now)
    //     })
    func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
        cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
        cp.name = callbackName
        cp.processor = &callback
        cp.replace = true
        cp.parent.processors = append(cp.parent.processors, cp)
        cp.parent.reorder()
    }
    

    Now let's see how the reordering is done:

    // reorder all registered processors, and reset CRUD callbacks
    func (c *Callback) reorder() {
        var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
    
        for _, processor := range c.processors {
            if processor.name != "" {
                switch processor.kind {
                case "create":
                    creates = append(creates, processor)
                case "update":
                    updates = append(updates, processor)
                case "delete":
                    deletes = append(deletes, processor)
                case "query":
                    queries = append(queries, processor)
                case "row_query":
                    rowQueries = append(rowQueries, processor)
                }
            }
        }
    
        c.creates = sortProcessors(creates)
        c.updates = sortProcessors(updates)
        c.deletes = sortProcessors(deletes)
        c.queries = sortProcessors(queries)
        c.rowQueries = sortProcessors(rowQueries)
    }
    

    The first half only groups the processors by kind. The real work happens in sortProcessors:

    // sortProcessors sort callback processors based on its before, after, remove, replace
    func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
        var (
            allNames, sortedNames []string
            sortCallbackProcessor func(c *CallbackProcessor)
        )
    
        for _, cp := range cps {
            // show warning message the callback name already exists
            if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
                cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
            }
            allNames = append(allNames, cp.name)
        }
    
        sortCallbackProcessor = func(c *CallbackProcessor) {
            if getRIndex(sortedNames, c.name) == -1 { // if not sorted
                if c.before != "" { // if defined before callback
                    if index := getRIndex(sortedNames, c.before); index != -1 {
                        // if before callback already sorted, append current callback just after it
                        sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
                    } else if index := getRIndex(allNames, c.before); index != -1 {
                        // if before callback exists but haven't sorted, append current callback to last
                        sortedNames = append(sortedNames, c.name)
                        sortCallbackProcessor(cps[index])
                    }
                }
    
                if c.after != "" { // if defined after callback
                    if index := getRIndex(sortedNames, c.after); index != -1 {
                        // if after callback already sorted, append current callback just before it
                        sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
                    } else if index := getRIndex(allNames, c.after); index != -1 {
                        // if after callback exists but haven't sorted
                        cp := cps[index]
                        // set after callback's before callback to current callback
                        if cp.before == "" {
                            cp.before = c.name
                        }
                        sortCallbackProcessor(cp)
                    }
                }
    
                // if current callback haven't been sorted, append it to last
                if getRIndex(sortedNames, c.name) == -1 {
                    sortedNames = append(sortedNames, c.name)
                }
            }
        }
    
        for _, cp := range cps {
            sortCallbackProcessor(cp)
        }
    
        var sortedFuncs []*func(scope *Scope)
        for _, name := range sortedNames {
            if index := getRIndex(allNames, name); !cps[index].remove {
                sortedFuncs = append(sortedFuncs, cps[index].processor)
            }
        }
    
        return sortedFuncs
    }
    

    It first collects the names of all cps, warning about duplicates along the way. sortedNames holds the names that have already been sorted.

    // getRIndex get right index from string slice
    func getRIndex(strs []string, str string) int {
        for i := len(strs) - 1; i >= 0; i-- {
            if strs[i] == str {
                return i
            }
        }
        return -1
    }
    

    getRIndex returns the index of the rightmost match, or -1 if there is none.
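
    Searching from the right matters when a name occurs more than once: the most recent registration is the one that gets found. A small illustration, using a hypothetical rightIndex function with the same logic:

```go
package main

import "fmt"

// rightIndex returns the index of the last occurrence of target, or -1.
func rightIndex(items []string, target string) int {
	for i := len(items) - 1; i >= 0; i-- {
		if items[i] == target {
			return i
		}
	}
	return -1
}

func main() {
	names := []string{"gorm:create", "plugin:audit", "gorm:create"}
	fmt.Println(rightIndex(names, "gorm:create")) // 2
	fmt.Println(rightIndex(names, "missing"))     // -1
}
```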

    Now for what the sortCallbackProcessor function actually does.

    It contains two conditional sections. The first one:

    if c.before != "" { // if defined before callback
      if index := getRIndex(sortedNames, c.before); index != -1 {
        // if before callback already sorted, append current callback just after it
        sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
      } else if index := getRIndex(allNames, c.before); index != -1 {
        // if before callback exists but haven't sorted, append current callback to last
        sortedNames = append(sortedNames, c.name)
        sortCallbackProcessor(cps[index])
      }
    }
    

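    There are two cases. If the before callback is already sorted, the current callback is inserted at its index, that is, directly in front of it (the source comment says "after", but the slice expression inserts before).

    If the before callback exists but has not been sorted yet, the current name is appended to the end of sortedNames.
    Then sortCallbackProcessor(cps[index]) is called recursively, which moves straight on to sorting the before callback.

    The second section: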

    if c.after != "" { // if defined after callback
      if index := getRIndex(sortedNames, c.after); index != -1 {
        // if after callback already sorted, append current callback just before it
        sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
      } else if index := getRIndex(allNames, c.after); index != -1 {
        // if after callback exists but haven't sorted
        cp := cps[index]
        // set after callback's before callback to current callback
        if cp.before == "" {
          cp.before = c.name
        }
        sortCallbackProcessor(cp)
      }
    }
    

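    The logic mirrors the first section. If the after callback is already sorted, the current callback is inserted at index+1, directly behind it (again the source comment says the opposite).

    If the after callback exists but is not sorted yet, its before field is set to the current callback, unless it already has one.
    Then sortCallbackProcessor(cp) is called recursively to sort the after callback.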
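
    Both branches rely on the same insert-at-index slice idiom. The helper name insertAt below is hypothetical; the expression inside it is the one used in sortCallbackProcessor.

```go
package main

import "fmt"

// insertAt returns a slice with v placed at position i; the elements from
// i onwards shift one place to the right.
func insertAt(s []string, i int, v string) []string {
	return append(s[:i], append([]string{v}, s[i:]...)...)
}

func main() {
	sorted := []string{"begin", "create", "commit"}
	// "before create": insert at create's index, which is 1.
	sorted = insertAt(sorted, 1, "validate")
	// create now sits at index 2; "after create" inserts at index+1.
	sorted = insertAt(sorted, 3, "audit")
	fmt.Println(sorted) // [begin validate create audit commit]
}
```

    The inner append builds a fresh slice first, so the tail is copied before the outer append overwrites the shared backing array.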

    // if current callback haven't been sorted, append it to last
    if getRIndex(sortedNames, c.name) == -1 {
      sortedNames = append(sortedNames, c.name)
    }
    

    If the callback has still not been placed, it is appended to the end. That is all there is to sortCallbackProcessor.

    for _, cp := range cps {
      sortCallbackProcessor(cp)
    }
    

    This loop starts the sorting. When it finishes, sortedNames is complete:

    var sortedFuncs []*func(scope *Scope)
    for _, name := range sortedNames {
      if index := getRIndex(allNames, name); !cps[index].remove {
        sortedFuncs = append(sortedFuncs, cps[index].processor)
      }
    }
    
    return sortedFuncs
    

    The callbacks that are not marked remove are appended to sortedFuncs in order.
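
    To watch the whole algorithm run, here is a reduced version that sorts names only. It is a sketch under assumptions, not GORM's code: proc, rindex, insertAt, and order are hypothetical names, and unlike the original it skips the "after" rule for a callback that the "before" rule has already placed.

```go
package main

import "fmt"

// proc is a hypothetical, reduced CallbackProcessor: names and constraints only.
type proc struct {
	name, before, after string
	remove              bool
}

// rindex returns the index of the last occurrence of name, or -1.
func rindex(names []string, name string) int {
	for i := len(names) - 1; i >= 0; i-- {
		if names[i] == name {
			return i
		}
	}
	return -1
}

// insertAt places v at position i, shifting later elements to the right.
func insertAt(s []string, i int, v string) []string {
	return append(s[:i], append([]string{v}, s[i:]...)...)
}

// order returns the callback names in execution order, without removed ones.
func order(ps []*proc) []string {
	var all, sorted []string
	for _, p := range ps {
		all = append(all, p.name)
	}
	var place func(p *proc)
	place = func(p *proc) {
		if rindex(sorted, p.name) != -1 {
			return // already placed
		}
		if p.before != "" {
			if i := rindex(sorted, p.before); i != -1 {
				sorted = insertAt(sorted, i, p.name) // directly in front of the target
			} else if i := rindex(all, p.before); i != -1 {
				sorted = append(sorted, p.name)
				place(ps[i]) // place the target next
			}
		}
		// Deviation from the original: skip "after" when "before" already placed p.
		if p.after != "" && rindex(sorted, p.name) == -1 {
			if i := rindex(sorted, p.after); i != -1 {
				sorted = insertAt(sorted, i+1, p.name) // directly behind the target
			} else if i := rindex(all, p.after); i != -1 {
				if ps[i].before == "" {
					ps[i].before = p.name
				}
				place(ps[i])
			}
		}
		if rindex(sorted, p.name) == -1 {
			sorted = append(sorted, p.name)
		}
	}
	for _, p := range ps {
		place(p)
	}
	var result []string
	for _, name := range sorted {
		// The rightmost registration of a name decides whether it is removed.
		if !ps[rindex(all, name)].remove {
			result = append(result, name)
		}
	}
	return result
}

func main() {
	fmt.Println(order([]*proc{
		{name: "begin"},
		{name: "create"},
		{name: "commit"},
		{name: "audit", after: "create"},
		{name: "validate", before: "create"},
		{name: "commit", remove: true},
	})) // [begin validate create audit]
}
```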

    Finally, there is a Get method for fetching a registered callback:

    // Get registered callback
    //    db.Callback().Create().Get("gorm:create")
    func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
        for _, p := range cp.parent.processors {
            if p.name == callbackName && p.kind == cp.kind {
                if p.remove {
                    callback = nil
                } else {
                    callback = *p.processor
                }
            }
        }
        return
    }
    

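    By now it should be clear how callbacks are registered and sorted, and how a single callback is fetched by name.

    Actual registration flow

    So far this was only the definition side. Let's see where callbacks are actually registered.

    When the DB is initialized, the Open method executes the following: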

    db = &DB{
      db:        dbSQL,
      logger:    defaultLogger,
      callbacks: DefaultCallback,
      dialect:   newDialect(dialect, dbSQL),
    }
    

    DefaultCallback is defined as follows:

    // DefaultCallback default callbacks defined by gorm
    var DefaultCallback = &Callback{}
    

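    At first this worried me a little: it is just an empty value, so it must be populated somewhere. A glance at the file list made it clear.

    The file callback_create.go defines the registration of the create callbacks.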

    // Define callbacks for creating
    func init() {
        DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
        DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
        DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
        DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
        DefaultCallback.Create().Register("gorm:create", createCallback)
        DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
        DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
        DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
        DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
    }
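
    The mechanism at work is Go's init function: a package-level variable starts out empty and is filled before main runs. A minimal sketch with hypothetical names (registry, defaultRegistry, register):

```go
package main

import "fmt"

// registry is a hypothetical stand-in for gorm's Callback type.
type registry struct {
	creates []string
}

func (r *registry) register(name string) { r.creates = append(r.creates, name) }

// defaultRegistry starts out empty, like DefaultCallback.
var defaultRegistry = &registry{}

// init runs automatically after package-level variables are initialized and
// before main, so the registry is populated by the time it is first used.
func init() {
	defaultRegistry.register("begin_transaction")
	defaultRegistry.register("create")
	defaultRegistry.register("commit_or_rollback_transaction")
}

func main() {
	fmt.Println(defaultRegistry.creates)
}
```

    A package may contain several init functions spread over several files, which would let each callback_*.go file register its own kind.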
    

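    With the documentation at hand,
    let's see how BeforeSave and BeforeCreate are implemented.

    When you define a model, you can implement methods such as BeforeSave and BeforeCreate on it,
    and they are called at the appropriate time.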

    func (u *User) BeforeSave() (err error) {
      if !u.IsValid() {
        err = errors.New("can't save invalid data")
      }
      return
    }
    
    func (u *User) AfterCreate(scope *gorm.Scope) (err error) {
      if u.ID == 1 {
        scope.DB().Model(u).Update("role", "admin")
      }
      return
    }
    

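    The above is the example from the official documentation. Earlier, the comments showed how to register a callback function by hand,
    much like DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback).
    But how do the methods defined on the model get called?

    Look at the beforeCreateCallback function: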

    // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
    func beforeCreateCallback(scope *Scope) {
        if !scope.HasError() {
            scope.CallMethod("BeforeSave")
        }
        if !scope.HasError() {
            scope.CallMethod("BeforeCreate")
        }
    }
    

    It is done through scope.CallMethod: pass a method name and that method is called.

    // CallMethod call scope value's method, if it is a slice, will call its element's method one by one
    func (scope *Scope) CallMethod(methodName string) {
        if scope.Value == nil {
            return
        }
    
        if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
            for i := 0; i < indirectScopeValue.Len(); i++ {
                scope.callMethod(methodName, indirectScopeValue.Index(i))
            }
        } else {
            scope.callMethod(methodName, indirectScopeValue)
        }
    }
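
    The slice-or-single branching can be tried in isolation. In this sketch forEach is a hypothetical helper, and reflect.Indirect stands in for scope.IndirectValue(), which is an assumption about what that method does.

```go
package main

import (
	"fmt"
	"reflect"
)

// forEach calls fn once for value itself, or once per element when value is
// a slice or a pointer to a slice.
func forEach(value interface{}, fn func(elem interface{})) {
	if value == nil {
		return
	}
	v := reflect.Indirect(reflect.ValueOf(value))
	if v.Kind() == reflect.Slice {
		for i := 0; i < v.Len(); i++ {
			fn(v.Index(i).Interface())
		}
		return
	}
	fn(v.Interface())
}

func main() {
	forEach(&[]string{"a", "b"}, func(elem interface{}) { fmt.Println("element:", elem) })
	forEach("single", func(elem interface{}) { fmt.Println("value:", elem) })
}
```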
    

    That was a detour. On to the code of callMethod:

    func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
        // Only get address from non-pointer
        if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr {
            reflectValue = reflectValue.Addr()
        }
    
        if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
            switch method := methodValue.Interface().(type) {
            case func():
                method()
            case func(*Scope):
                method(scope)
            case func(*DB):
                newDB := scope.NewDB()
                method(newDB)
                scope.Err(newDB.Error)
            case func() error:
                scope.Err(method())
            case func(*Scope) error:
                scope.Err(method(scope))
            case func(*DB) error:
                newDB := scope.NewDB()
                scope.Err(method(newDB))
                scope.Err(newDB.Error)
            default:
                scope.Err(fmt.Errorf("unsupported function %v", methodName))
            }
        }
    }
    

    All this flexibility comes from reflection. The key line is methodValue := reflectValue.MethodByName(methodName).

    The switch shows that the method may have any of several signatures:

    switch method := methodValue.Interface().(type) {
    case func():
      method()
    case func(*Scope):
      method(scope)
    case func(*DB):
      newDB := scope.NewDB()
      method(newDB)
      scope.Err(newDB.Error)
    case func() error:
      scope.Err(method())
    case func(*Scope) error:
      scope.Err(method(scope))
    case func(*DB) error:
      newDB := scope.NewDB()
      scope.Err(method(newDB))
      scope.Err(newDB.Error)
    default:
      scope.Err(fmt.Errorf("unsupported function %v", methodName))
    }
    

    So all of this can be read as one large demonstration of how to use reflect.
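
    The technique fits in a short standalone program. The sketch below supports only two of the six signatures, and user and callHook are hypothetical names.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// user is a hypothetical model with one hook defined on the pointer receiver.
type user struct{ Name string }

func (u *user) BeforeSave() error {
	if u.Name == "" {
		return errors.New("name must not be empty")
	}
	return nil
}

// callHook looks up a method by name and calls it if its signature is one
// of the supported shapes. A missing method is not an error: hooks are optional.
func callHook(model interface{}, name string) error {
	m := reflect.ValueOf(model).MethodByName(name)
	if !m.IsValid() {
		return nil
	}
	switch fn := m.Interface().(type) {
	case func():
		fn()
		return nil
	case func() error:
		return fn()
	default:
		return fmt.Errorf("unsupported function %v", name)
	}
}

func main() {
	fmt.Println(callHook(&user{Name: "jinzhu"}, "BeforeSave")) // <nil>
	fmt.Println(callHook(&user{}, "BeforeSave"))               // name must not be empty
	fmt.Println(callHook(&user{}, "AfterFind"))                // <nil>
}
```

    The type switch on m.Interface() is what lets one call site accept several hook signatures without the model having to implement an interface.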

    createCallback

    The other hooks are skipped here. Let's look at what happens when a single record is inserted:

    // createCallback the callback used to insert data into database
    func createCallback(scope *Scope) {
        if !scope.HasError() {
            defer scope.trace(scope.db.nowFunc())
    
            var (
                columns, placeholders        []string
                blankColumnsWithDefaultValue []string
            )
    
            for _, field := range scope.Fields() {
                if scope.changeableField(field) {
                    if field.IsNormal && !field.IsIgnored {
                        if field.IsBlank && field.HasDefaultValue {
                            blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
                            scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
                        } else if !field.IsPrimaryKey || !field.IsBlank {
                            columns = append(columns, scope.Quote(field.DBName))
                            placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
                        }
                    } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
                        for _, foreignKey := range field.Relationship.ForeignDBNames {
                            if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
                                columns = append(columns, scope.Quote(foreignField.DBName))
                                placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
                            }
                        }
                    }
                }
            }
    
            var (
                returningColumn = "*"
                quotedTableName = scope.QuotedTableName()
                primaryField    = scope.PrimaryField()
                extraOption     string
                insertModifier  string
            )
    
            if str, ok := scope.Get("gorm:insert_option"); ok {
                extraOption = fmt.Sprint(str)
            }
            if str, ok := scope.Get("gorm:insert_modifier"); ok {
                insertModifier = strings.ToUpper(fmt.Sprint(str))
                if insertModifier == "INTO" {
                    insertModifier = ""
                }
            }
    
            if primaryField != nil {
                returningColumn = scope.Quote(primaryField.DBName)
            }
    
            lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
    
            if len(columns) == 0 {
                scope.Raw(fmt.Sprintf(
                    "INSERT %v INTO %v %v%v%v",
                    addExtraSpaceIfExist(insertModifier),
                    quotedTableName,
                    scope.Dialect().DefaultValueStr(),
                    addExtraSpaceIfExist(extraOption),
                    addExtraSpaceIfExist(lastInsertIDReturningSuffix),
                ))
            } else {
                scope.Raw(fmt.Sprintf(
                    "INSERT %v INTO %v (%v) VALUES (%v)%v%v",
                    addExtraSpaceIfExist(insertModifier),
                    scope.QuotedTableName(),
                    strings.Join(columns, ","),
                    strings.Join(placeholders, ","),
                    addExtraSpaceIfExist(extraOption),
                    addExtraSpaceIfExist(lastInsertIDReturningSuffix),
                ))
            }
    
            // execute create sql
            if lastInsertIDReturningSuffix == "" || primaryField == nil {
                if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
                    // set rows affected count
                    scope.db.RowsAffected, _ = result.RowsAffected()
    
                    // set primary value to primary field
                    if primaryField != nil && primaryField.IsBlank {
                        if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
                            scope.Err(primaryField.Set(primaryValue))
                        }
                    }
                }
            } else {
                if primaryField.Field.CanAddr() {
                    if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
                        primaryField.IsBlank = false
                        scope.db.RowsAffected = 1
                    }
                } else {
                    scope.Err(ErrUnaddressable)
                }
            }
        }
    }
    

    First, the for loop iterates over all fields and fills the three slices declared at the top.

    for _, field := range scope.Fields() {
      if scope.changeableField(field) {
        if field.IsNormal && !field.IsIgnored {
          if field.IsBlank && field.HasDefaultValue {
            blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
            scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
          } else if !field.IsPrimaryKey || !field.IsBlank {
            columns = append(columns, scope.Quote(field.DBName))
            placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
          }
        } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
          for _, foreignKey := range field.Relationship.ForeignDBNames {
            if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
              columns = append(columns, scope.Quote(foreignField.DBName))
              placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
            }
          }
        }
      }
    }
    

    Next it gathers and sets a few pieces of information:

    var (
      returningColumn = "*"
      quotedTableName = scope.QuotedTableName()
      primaryField    = scope.PrimaryField()
      extraOption     string
      insertModifier  string
    )
    

    With all of that in hand, it builds the INSERT statement:

    if len(columns) == 0 {
      scope.Raw(fmt.Sprintf(
        "INSERT %v INTO %v %v%v%v",
        addExtraSpaceIfExist(insertModifier),
        quotedTableName,
        scope.Dialect().DefaultValueStr(),
        addExtraSpaceIfExist(extraOption),
        addExtraSpaceIfExist(lastInsertIDReturningSuffix),
      ))
    } else {
      scope.Raw(fmt.Sprintf(
        "INSERT %v INTO %v (%v) VALUES (%v)%v%v",
        addExtraSpaceIfExist(insertModifier),
        scope.QuotedTableName(),
        strings.Join(columns, ","),
        strings.Join(placeholders, ","),
        addExtraSpaceIfExist(extraOption),
        addExtraSpaceIfExist(lastInsertIDReturningSuffix),
      ))
    }
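
    The column collection and statement building can be condensed into a standalone sketch. It is not GORM's code: insertField and buildInsert are hypothetical names, placeholders are fixed to "?", the no-column case uses "DEFAULT VALUES", and relationships, insert modifiers, and extra options are left out.

```go
package main

import (
	"fmt"
	"strings"
)

// insertField is a hypothetical, reduced version of a model field.
type insertField struct {
	Column     string
	Value      interface{}
	Blank      bool // the field holds its zero value
	HasDefault bool // the column has a database-side default
	PrimaryKey bool
}

// buildInsert collects columns and placeholders the way the loop above does:
// blank fields with a default are left to the database, and a blank primary
// key is left out so that the database can generate it.
func buildInsert(table string, fields []insertField) (string, []interface{}) {
	var columns, placeholders []string
	var vars []interface{}
	for _, f := range fields {
		if f.Blank && f.HasDefault {
			continue
		}
		if !f.PrimaryKey || !f.Blank {
			columns = append(columns, fmt.Sprintf("%q", f.Column))
			placeholders = append(placeholders, "?")
			vars = append(vars, f.Value)
		}
	}
	if len(columns) == 0 {
		return fmt.Sprintf("INSERT INTO %q DEFAULT VALUES", table), nil
	}
	return fmt.Sprintf("INSERT INTO %q (%s) VALUES (%s)",
		table, strings.Join(columns, ","), strings.Join(placeholders, ",")), vars
}

func main() {
	sql, vars := buildInsert("products", []insertField{
		{Column: "id", Value: 0, Blank: true, PrimaryKey: true},
		{Column: "code", Value: "L1212"},
		{Column: "price", Value: 1000},
	})
	fmt.Println(sql)
	fmt.Println(vars)
}
```

    Values travel separately from the statement text, so they reach the driver as bind parameters rather than being spliced into the SQL.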
    

    Finally it executes the SQL statement:

    // execute create sql
    if lastInsertIDReturningSuffix == "" || primaryField == nil {
      if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
        // set rows affected count
        scope.db.RowsAffected, _ = result.RowsAffected()
    
        // set primary value to primary field
        if primaryField != nil && primaryField.IsBlank {
          if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
            scope.Err(primaryField.Set(primaryValue))
          }
        }
      }
    } else {
      if primaryField.Field.CanAddr() {
        if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
          primaryField.IsBlank = false
          scope.db.RowsAffected = 1
        }
      } else {
        scope.Err(ErrUnaddressable)
      }
    }
    

    The first condition depends on lastInsertIDReturningSuffix. Only the PostgreSQL dialect returns a non-empty string for it.

    var userid int
    err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age)
        VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid)
    

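    The PostgreSQL driver does not support the LastInsertId() method, so the ID has to be fetched with a RETURNING clause as shown above.
    See PostgreSQL Queries.

    That is why the two execution paths differ.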

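    That completes the createCallback callback, and with it the insert path.

    Summary

    This part covered how tables are created and migrated, how hook functions are registered and sorted, and when they are called.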

        Article link: https://www.haomeiwen.com/subject/sqvaactx.html