美文网首页
go-sql-driver 源码解析

go-sql-driver 源码解析

作者: HackerZGZ | 来源:发表于2017-10-19 01:18 被阅读0次

Intro

最近正在给 mysql 封装一个库,顺带研究一下 go-mysql-driver 这个库的源码实现。

Buffer.go

buffer 是一个用于给 数据库连接 (net.Conn) 进行缓冲的一个数据结构,其结构为:

type buffer struct {
    buf     []byte     // 缓冲池中的数据
    nc      net.Conn   // 负责缓冲的数据库连接对象
    idx     int        // 已读数据索引
    length  int        // 缓冲池中未读数据的长度
    timeout time.Duration // 数据库连接的超时设置
}

可以看到,因为 数据库连接 (net.Conn) 在通信的时候是 同步 的。而为了让其能够 同时 读/写 ,所以实现了 buffer 这个数据结构,通过该 buffer 进行数据缓冲还能实现 零拷贝 ( zero-copy-ish ) 。

其函数分别有:

  • newBuffer(nc net.Conn) buffer :创建并返回一个 buffer

  • (*buffer) readNext(need int) ([]byte, error) :读取并返回未读数据的 need 位,如果 need 大于 bufferlength ,就会调用 fill(need int) errorbuffer进行 扩容

  • (*buffer) fill(need int) error :对 buffer 进行 (need/defaultBufSize) 的倍数扩容,并在 timeout 时间结束前从 buffer.nc 中读取 need 长度的数据。

  • (*buffer) takeBuffer(length int) []byte :读取 bufferlength 长度的数据(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil 。如果需要读取的长度大于 buffer 的容量,则会进行扩容。

  • (*buffer) takeSmallBuffer(length int) []byte :读取保证不超过 defaultBufSize 长度的数据的快捷函数(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil

  • (*buffer) takeCompleteBuffer() []byte : 读取全部的 buffer 数据(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil

Collations.go

collations 包含了 MySQL 所有支持的 字符集 格式,并支持通过 COLLATION_NAME 返回其字符集 ID

如果需要查询 MySQL 支持的 字符集 格式,可以使用 SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS 语句获取。

Dsn.go

DSN数据源名称 (Data Source Name) ,是 驱动程序连接数据库的变量信息 ,简而言之就是根据你连接的不同数据库使用对应的连接信息。

通常,数据库的连接配置就是在这里定义的:

// Config 基本的数据库连接信息
type Config struct {
    User         string            // Username
    Passwd       string            // Password (requires User)
    Net          string            // Network type
    Addr         string            // Network address (requires Net)
    DBName       string            // Database name
    Params       map[string]string // Connection parameters
    Collation    string            // Connection collation
    Loc          *time.Location    // Location for time.Time values
    TLSConfig    string            // TLS configuration name
    tls          *tls.Config       // TLS configuration
    Timeout      time.Duration     // Dial timeout
    ReadTimeout  time.Duration     // I/O read timeout
    WriteTimeout time.Duration     // I/O write timeout

    AllowAllFiles           bool // 允许文件使用 LOAD DATA LOCAL INFILE 导入数据库
    AllowCleartextPasswords bool // 支持明文密码客户端
    AllowOldPasswords       bool // 允许使用不可靠的旧密码
    ClientFoundRows         bool // 返回匹配的行数而不是受影响的行数
    ColumnsWithAlias        bool // 将表名前置在列名
    InterpolateParams       bool // 将占位符插入查询的SQL字符串
    MultiStatements         bool // 允许一条语句多次查询
    ParseTime               bool // 格式化时间值为 time.Time 变量
    Strict                  bool // 将 warnings 返回 errors
}

这都是一些常见的配置项,就此略过。

该文件有两个公共函数支持 ConfigDSN 之间转换。

  • (*Config)FormatDSN() string

  • ParseDSN(dsn string) (*Config, error)

Errors.go

errors 定义了 LoggerMySQLErrorMySQLWarning 等数据结构。

Logger

复用了 Go 原生的 log 包,并将其中的输出重定向至控制台的 标准错误

type Logger interface {
  Print(v ...interface{})
}

var errLog = Logger(log.New(os.Stderr, "[mysql]", log.Ldate|log.Ltime|log.Lshortfile))

func SetLogger(logger Logger) error { // 当然,你也可以使用自定义的错误 Logger
  if logger == nil {
    return errors.New("logger is nil")
  }
  errLog =logger
  return nil
}

MySQLError

MySQLError 则简单定义了 MySQL 输出的错误的结构。

type MySQLError struct {
    Number  uint16
    Message string
}

MySQLWarning

MySQLWarning 则有些不一样,它需要从 MySQL 中进行一次 查询 ,以获取所有的警告信息,所以该包也定义了 MySQLWarningslice 结构。

type MySQLWarning struct {
    Level string
    Code string
    Message string
}

type MySQLWarnings []MySQLWarning

func (mc *mysqlConn) getWarnings() (err error) {
  rows, err := mc.Query("SHOW WARNINGS", nil)
  // handle err
  
  // initzation MySQLWarnings
  
  for {
    err = rows.Next(values)
    switch err {
      case nil:
        warning := MySQLWarning{}
        
      if raw, ok := values[0].([]byte); ok {
          warning.Level = string(raw)
      }else {
          warning.Level = fmt.Sprintf("%s", values[0])
      }
      
      if raw, ok := values[1].([]byte); ok {
        warning.Code = string(raw)
      } else {
        warning.Code = fmt.Sprintf("%s", values[1])
      }
      
      if raw, ok := values[2].([]byte); ok {
        warning.Message = string(raw)
      } else {
        warning.Message = fmt.Sprintf("%s", values[0])
      }

      warnings = append(warnings, warning)
    }
    
    case io.EOF:
        return warnings
    
    default:
        rows.Close() // 值得注意的是,如果该函数没有 case 运行 default ,该 rows 就不会被默认关闭,就会占用连接池中的一个连接,是否应该使用 `defer rows.Close() ` 避免该情况?
        return
  }
}

Infile.go

前面也有提到 MySQL 在导入大型文件的时候,需要使用 LOAD DATA LOCAL INFILE 的形式进行导入,而该 infile.go 就是实现该协议的代码。

本包在实现的 LOAD DATA 的时候提供了两种方式进行导入:

  • 最常见的,使用服务器的文件路径,如 /data/students.csv ,下文命名其为 文件路径注册器

  • 最通用的,使用实现了 io.Reader 接口的数据结构,通过返回该数据结构的数据进行导入,如 bytes os.file 等,下文命名其为 Reader 接口注册器

在实现该功能的时候,注册器 的实现是用名字作为 Key 的 Map ,为了避免 Map读写竞态 ,需要对其配置一个读写锁。

var (
    fileRegister        map[string]bool     // 文件路径注册器
    fileRegisterLock    sync.RWMutex        // 文件路径注册器读写锁
    readerRegister      map[string]func() io.Reader // Reader 接口注册器
    readerRegisterLock  sync.RWMutex                // Reader 接口注册器读写锁  
)

除了对两个注册器的 注册 以及 注销 函数,还有一个需要分析的一个函数:

(mc *mysqlConn) handleInFileRequest(name string) (err error)

通过传入 文件路径 或者 Reader 名称 就可以将数据发往 MySQL 了。

func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
    packSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
    if mc.maxWriteSize < packSize { // 设置发往 MySQL 的数据块大小
        packSize = mc.maxWriteSize
    }
  
    // 获取 文件 或 Reader 的数据,并将其赋值到 rdr 中
    // var rdr io.Reader
  
    // send context packets
    if err != nil {
        data := make([]byte, 4+packetSize) // 需要留 4 个 byte 给协议使用
        var n int
        for err == nil {
            n, err = rdr.Read(data[4:]) // 将数据存入 data 的 [4:] 中
            if n > 0 {
                if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { // 将 data 数据发往 MySQL
                    return ioErr
                }
            }
        }
        if err == io.EOF { // rdr 中的数据读完了
            err = nil
        }
    }
  
    // send empty packet (termination)
    if data == nil {
        data = make([]byte, 4)
    }
    if ioErr := mc.writePacket(data[:4]); ioErr != nil { // 告诉 MySQL 文件发送完毕
        return ioErr
    }

    // read OK packet
    if err == nil { // 一切正常结束
        return mc.readResultOK()
    }

    mc.readPacket() // 如果中途出错,将错误信息读取到 mysqlConn 中,并返回该错误
    return err
}

到此,infile.go 的实现已经整理完毕了,可以看到, 作者 在实现这个功能的时候还是做了一些优化的,比如 map Lazy initsend packet size limited 等。而我们通过分析规范的源码包,能够提升自己的编码水平。

Packets.go

接下来就要深入到 MySQL 的通信协议中了,官方的 通信协议文档 非常齐全,我在这里只将一些基础的,我后面分析源码会用到的协议分析下,如果有兴趣,可以到官方文档处进行查阅。

Protocol Basics

基础数据类型

MySQL 通信的基本数据类型有两种, IntegerString

  • Integer : 分别有 12348 个字节长度的类型,使用小端传输。

  • String : 分别有 固定长度字符串(协议规定)NULL结尾字符串(长度不固定)长度编码字符串(长度不固定)

报文协议

报文分为 消息头 以及 消息体,而 消息头 由 3 字节的 消息长度 以及 1 字节的 序号 sequence (新客户端由 0 开始)组成,消息体 则由 消息长度 的字节组成。

  • 3 字节的 消息长度 最大值为 0xFFFFFF ,即为 16 MB - 1 byte ,这就意味着,如果整个消息(不包括消息头)的长度大于 16MB - 1byte - 4byte 大小时,消息就会被分包。

  • 1 字节的 序号 在每次新的客户端发起请求时,以 0 开始,依次递增 1 ,如果消息需要分包, 序号 会随着分包的数量递增。而在一次应答中, 客户端会校验服务器 返回序号 是否与 发送序号 一致,如果不一致,则返回错误异常。

协议类型

  • handshake : 发起连接

  • auth : 登录权限校验

  • ok | error : 返回结果状态 *

  • ok : 首字节为 0 (0x00

  • error : 首字节为 255 (0xff

  • resultset : 结果集

  • header

  • field

  • eof

  • row

  • command package : 命令

在整个 MySQL 发起交互的过程如下图所示:

mysql connect

在了解这些 MySQL 基础协议知识后,我们再来看 packages.go 的源码就轻松多了。

源码

先来看看 readPacket ,结合上面的知识点应该非常好理解。

func (mc *mysqlConn) readPacket() ([]byte, error) {
    var payload []byte
    for { // for 循环是为了读取有可能分片的数据
        // Read package header
        data, err := mc.buf.readNext(4) // 从 buffer 缓冲器中读取 4 字节的 header
        if err != nil { // 如果读取发生异常,则关闭连接,并返回一个错误连接的异常
            errLog.Print(err)
            mc.Close()
            return nil, driver.ErrBadConn 
        }
      
        // Packet Length [24 bit]
        pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) // 读取 3 字节的消息长度
      
        if pktLen < 1 {
            // 如上所示,关闭连接,并返回一个错误连接的异常
        }
      
        // Check Packet Sync [8 bit]
        if data[3] != mc.sequence { // 判断服务端返回的序号是否与客户端一致
            if data[3] > mc.sequence {
                return nil, ErrPktSyncMul // 如果服务端返回序号大于客户端的序号,则有可能是在一次请求中做了多次操作
            }
            return nil, ErrPktSync // 返回序号不一致错误
        }
        mc.sequence++ // 本次序号匹配相符,为了匹配下一次请求,先将序号自增1
      
        data, err := mc.buf.readNext(pktLen) // 读取 消息长度 的数据
        if err != nil {
            // 如上所示,关闭连接,并返回一个错误连接的异常
        }
      
        isLastPacket := (pktLen < maxPacketSize) // 如果是最后一个数据包,必然小于 maxPacketSize (16MB - 1byte)
        
        // Zero allocations for non-splitting packets
        if isLastPacket && payload == nil { // 无分包情况,立即返回
            return data, nil
        }

        payload = append(payload, data...)

        if isLastPacket { // 如果是最后一个包,读取完毕后返回
            return payload, nil
        }
      
        // 还有未读数据,开始下一次循环
    }
}

下面来看下结合 握手报文协议 来看下客户端向服务端发起请求的 readInitPacket

mysql handshack protocol
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
    data, err := mc.readPacket() // 调用上面的函数读取服务端返回的数据
    if err != nil {
        return nil, err
    }
  
    if data[0] == iERR { // iERR = 0xff  消息体的第一个字节返回 0xff ,则意味着 error package
        return nil, mc.handleErrorPacket(data)
    }
  
    // protocol version [1 byte]
    if data[0] < minProtocolVersion { // 判断是否是兼容的协议版本
        return nil, fmt.Errorf(
            "unsupported protocol version %d. Version %d or higher is required",
            data[0],
            minProtocolVersion,
        )
    }
  
    // server version [null terminated string]
    // connection id [4 bytes]
    pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // 读取 NULL (0x00)为结尾的字符串,跳过服务器线程 ID
  
    // first part of the password cipher [8 bytes]
    cipher := data[pos : pos+8] // 获取挑战随机数
  
    // (filler) always 0x00 [1 byte]
    pos += 8 + 1
  
    // capability flags (lower 2 bytes) [2 bytes]
    mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) // 获取服务器权能标识
    if mc.flags&clientProtocol41 == 0 { // 说明 MySQL 服务器不支持高于 41 版本的协议
        return nil, ErrOldProtocol
    }
    if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { // 说明 MySQL 服务器需要 SSL 加密,但是客户端没有配置 SSL
        return nil, ErrNoTLS
    }
    pos += 2 // 指针向后两位
  
    if len(data) > pos {
        // 指针跳过标志位
        pos += 1 + 2 + 2 + 1 + 10

        // second part of the password cipher [mininum 13 bytes],
        // where len=MAX(13, length of auth-plugin-data - 8)
        //
        // The web documentation is ambiguous about the length. However,
        // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
        // the 13th byte is "\0 byte, terminating the second part of
        // a scramble". So the second part of the password cipher is
        // a NULL terminated string that's at least 13 bytes with the
        // last byte being NULL.
        //
        // The official Python library uses the fixed length 12
        // which seems to work but technically could have a hidden bug.
        cipher = append(cipher, data[pos:pos+12]...)

        // TODO: Verify string termination
        // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
        // \NUL otherwise
        //
        //if data[len(data)-1] == 0 {
        //  return
        //}
        //return ErrMalformPkt

        // make a memory safe copy of the cipher slice
        var b [20]byte
        copy(b[:], cipher)
        return b[:], nil
    }

    // make a memory safe copy of the cipher slice
    var b [8]byte // 返回 8 字节的挑战随机数
    copy(b[:], cipher)
    return b[:], nil
}

除了上面解析的两个函数, packages.go 还有 initialisation process / result packages / prepared statements 等协议的 写入/读取 ,有兴趣的读者可以结合上面的知识点自行阅读。

Driver.go

接下来就要分析一些比较重要的代码了,比如接下来要讲的 driver.go ,它主要负责与 MySQL 数据库进行各种协议的连接,并返回该连接。可以说它才是最基础、最核心的功能。

不过首先我们需要看下 database/sql 包中的 Driver 接口需要如何实现:

// database/sql/driver/driver.go

// 数据库驱动
type Driver interface {
  Open(name string) (Conn, error)
}

// ...

// 非并发安全数据库连接
type Conn interface {
  // 返回一个绑定到 sql 的准备语句
  Prepare(query string) (Stmt, error)
  
  // 关闭该连接,并标记为不再使用,停止所有准备语句和事务
  // 因为 database/sql 包维护了一个空闲的连接池,并且在空闲连接过多的时候会自动调用 Close ,所以驱动程序包不需要显式调用该函数
  Close() error
  
  // 开始并返回一个新的事务,而新的事务与旧的连接没有任何关联
  Begin() (Tx, error)
}

根据 database/sql 提供的 Driver 接口, go-sql-driver/mysql 实现了自己的 数据库驱动 结构:

type MySQLDriver struct{}

func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
  mc := &mysqlConn {
      // set max value
  }
  mc.cfg = ParseDSN(dsn) // 通过解析 DSN 设置 MySQL 连接的配置

  // set parseTime and strict
  // ...
  
  // connect to server
  if dial, ok := dials[mc.cfg.Net]; ok { // 根据 地址 以及 协议类型,尝试连接上服务器
    mc.netConn, err = dial(mc.cfg.Addr)
  } else { // 连接服务器失败,尝试重连
    nd := net.Dialer{Timeout: mc.cfg.Timeout}
    mc.netConn, err := nd.Dial(mc.cfg.Net, mc.cfg.Addr)
  }
  if err !=  nil { // 重试失败,返回异常
      return nil, err
  }
  
  // Enable TCP Keepalives on TCP connections
  if tc, ok := mc.netConn.(*net.Conn); ok { // tcp 连接类型转换
    if err := tc.SetKeepAlive(true); err != nil {
      // Don't send COM_QUIT before handshake.
      mc.netConn.Close() // 如果设置长连接失败,返回异常之前一定要记得将连接断开
      mc.netConn = nil
      return nil, err
    }
  }
  
  mc.buff = newBuff(mc.netConn) // 生成一个带缓冲的 buffer,如上面 buffer.go 中所说
  
  // set I/O timeout
  // ...
  
  // Reading Handshake Initialization Packet
  cipher, err := mc.readInitPacket() // 发起数据库首次握手
  if err != nil {
    mc.cleanup() // 将当前 mysqlConn 对象销毁,后面我们会说这个函数
    return nil, err
  }
  
  // Send Client Authentication Packet
  if err = mc.writeAuthPacket(cipher); err != nil { // 向数据库发送登录信息校验
    mc.cleanup()
    return nil, err
  }
}

connection.go

终于要讲到这个包的核心数据结构 mysqlConn 了,可以说,驱动的所有功能几乎都围绕着这个数据结构,我们先来看看它的结构:

type mysqlConn struct {
    buf              buffer     // buffer 缓冲器
    netConn          net.Conn   // 网络连接
    affectedRows     uint64     // sql 执行成功影响行数
    insertId         uint64     // sql 添加成功最新的主键 ID
    cfg              *Config    // dsn 中的 基础配置
    maxPacketAllowed int        // 允许的最大报文的字节长度,最大不能超过 (16MB - 1byte)
    maxWriteSize     int        // 允许最大的写入字节长度,最大不能超过 (16MB - 1byte)
    writeTimeout     time.Duration  // 执行 sql 的 超时时间
    flags            clientFlag     // 客户端状态标识
    status           statusFlag     // 服务端状态标识
    sequence         uint8          // 序号
    parseTime        bool           // 是否格式化时间
    strict           bool           // 是否使用严格模式
}

// driver.go
// 而创建一个 mysqlConn 连接需要通过 driver.go 中的 Open 函数,也说明 mysqlConn 实现了 driver.Conn 接口
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
  mc := &mysqlConn{
      // ...
  }
  
  // ...
  
  return mc, nil
}

当一个新的客户端连接上服务器的时候 (三次握手结束,客户端进入 established 状态),需要先对 MySQL 服务器进行 会话的用户/系统环境变量 的设置。

// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
    for param, val := range mc.cfg.Params { // Params: map[string]string
        switch param {
        // Charset
        case "charset": // 如果是字符集,则调用 SET NAMES 命令
            charsets := strings.Split(val, ",")
            for i := range charsets {
                // ignore errors here - a charset may not exist
                err = mc.exec("SET NAMES " + charsets[i])
                if err == nil {
                    break
                }
            }
            if err != nil {
                return
            }

        // System Vars
        default: // 执行系统环境变量设置
            err = mc.exec("SET " + param + "=" + val + "")
            if err != nil {
                return
            }
        }
    }
}

conntion.go 还负责 事务预处理语句执行/查询 的管理,但是基本都是往 mysqlConn 中发送 command package ,如:

// Begin 开启事务
func (mc *mysqlConn) Begin() (driver.Tx, error) {
    if mc.netConn == nil {
        errLog.Print(ErrInvalidConn)
        return nil, driver.ErrBadConn
    }
    err := mc.exec("START TRANSACTION")
    if err == nil {
        return &mysqlTx{mc}, err // 返回成功开启的事务,重用之前的连接
    }

    return nil, err
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
    // Send command
    err := mc.writeCommandPacketStr(comQurey, query)
    if err != nil {
        return err
    }
  
    // Read Result
    resLen, err := mc.readResultSetHeaderPacket() // 根据 data[0] 的值判断是否出错,如果没有错误,则返回消息体的长度
    if err == nil && resLen > 0 { // 存在有效消息体
        if err = mc.readUntilEOF(); err != nil { // 读取 columns
            return err
        }

        err = mc.readUntilEOF() // 读取 rows
    }

    return err
}

我想 conntion.go 中最重要的一个函数应该是 cleanup ,它负责将 连接关闭重置环境变量 等功能,但是该函数不能随意调用,它只有在 登录权限校验异常 时候才应该被调用,否则服务器在不知道客户端 被强行关闭 的情况下,依然会向该客户端发送消息,导致严重异常:

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
    // Makes cleanup idempotent 保证函数的幂等性
    if mc.netConn != nil {
        if err := mc.netConn.Close(); err != nil { // Close 会尝试发送 comQuit command 到服务器
            errLog.Print(err)
        }
        mc.netConn = nil // 不管 Close 是否成功,必须将 netConn 清空
    }
    mc.cfg = nil
    mc.buf.nc = nil // 缓冲器中的 netConn 也要关闭
}

Result.go

每当 MySQL 返回一个 OK状态报文 ,该报文协议会携带上本次执行的结果 affectedRows 以及 insertId ,而 result.go 就包含着一个数据结构,用于存储本次的执行结果。

type mysqlResult struct {
    affectedRows int64
    insertId     int64
}

// 两个 getter
func (res *mysqlResult) LastInsertId() (int64, error) {
    return res.insertId, nil
}

func (res *mysqlResult) RowsAffected() (int64, error) {
    return res.affectedRows, nil
}

接下来我们看下在 conntion.go 中是怎么生成 mysqlResult 对象的:

// connect.go
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  
    // ...
  
    err := exec(query)
    if err == nil {
        return &mysqlResult{ // 返回执行的结果
            affectedRows: int64(mc.affectedRows),
            insertId:     int64(mc.insertId),
        }, err
    }
    return nil, err
}

// exec 函数的解析可以返回上面 package.go 中浏览

// package.go
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
    data, err := mc.readPacket()
    if err == nil {
        switch data[0] {

        case iOK:
            return 0, mc.handleOkPacket(data) // 处理 OK 状态报文

        // ...
}

func (mc *mysqlConn) handleOkPacket(data []byte) error {
    var n, m int

    // 0x00 [1 byte]

    // Affected rows [Length Coded Binary]
    mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])

    // Insert id [Length Coded Binary]
    mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

    // ...
}

Row.go

MySQL 执行 插入、更新、删除 等操作后,都会返回 Result ,但是 查询 返回的是 Rows ,我们先来看看 go-mysql-driver 驱动所实现的 接口 Rows 的接口描述:

// database/sql/driver/driver.go
// Rows 是执行查询返回的结果的 游标
type Rows interface {
    // Columns 返回列的名称,从 slice 的长度可以判断列的长度
    // 如果一个列的名称未知,则为该列返回一个空字符串
    Columns() []string
  
    // Close 关闭游标
    Close() error
  
    // Next 将下一行数据填充到 desc 切片中
    // 如果读取的是最后一行数据,应该返回一个 io.EOF 错误
    Next(desc []Value) error
}

type Value interface{} // Value is a value that drivers must be able to handle.

为什么我要说这是 go-mysql-driver 驱动所实现的 接口 Rows 呢?眼尖的同学应该已经看到了, Next 函数好像和我们平常见到的不一样啊!!

是的,因为我们平常使用的:

  • rows.Next()

  • rows.Scan(dest ...interface{}) error

等函数的对象 rows 并不是上面的 接口描述 Rows ,而是另一个封装的 同名数据结构 Rows ,它就在 database/sql 包中 :

// database/sql.go
type Rows struct {
    dc          *driverConn 
    releaseConn func(error)
    rowsi       driver.Rows // 接口描述的 Rows 藏在这!!!
    
    // 忽略其他字段,因为我们不分析这个包...
  
    // lastcols is only used in Scan, Next, and NextResultSet which are expected
    // not not be called concurrently.
    lastcols []driver.Value
}

我们跳过 database/sql 包中的 Rows 实现,其无非是提供了更多功能的一个结果集而已,让我们回到真正与数据库进行交互的 Rows 中进行源码分析。

go-sql-driver 实现的 mysqlRows 数据结构只实现了 Columns()Close() 两个行数,剩下的 Next(desc []driver.Value) 实现则交给了 MySQL 的两种结果集协议:

// rows.go

type mysqlField struct {
    tableName string
    name      string
    flags     fieldFlag
    fieldType byte
    decimals  byte
}

type mysqlRows struct {
    mc      *mysqlConn
    columns []mysqlField
}

type binaryRows struct { // 二进制结果集协议
    mysqlRows // 对于 Go 的 组合特性 应该不会陌生吧?
}

type textRows struct { // 文本结果集协议
    mysqlRows
}

func (rows *mysqlRows) Columns() []string {
    columns := make([]string, len(rows.columns))
    
    // 将列名赋值到 columns ,如果有设置别名则赋值别名...
  
    return columns
}

func (rows *mysqlRows) Close() error {
    // 将连接里面的未读数据读完,然后将连接置空
}

// 接下来的 Next 函数实现就交由 binaryRows 和 textRows 了
func (rows *binaryRows) Next(desc []driver.Value) error {
    if mc := rows.mc; mc != nil {
        if mc.netConn == nil {
            return ErrInvalidConn
        }

        return rows.readRow(dest) // 读二进制协议结果集
    }
    return io.EOF
}

func (rows *testRows) Next(desc []driver.Value) error {
    if mc := rows.mc; mc != nil {
        if mc.netConn == nil {
            return ErrInvalidConn
        }

        return rows.readRow(dest) // 读取文本协议
    }
    return io.EOF
}

可以说,实现了 driver.Rows 接口的只有 binaryRowstestRows ,而他们里面的 readRow(desc) 实现由于都是和协议强相关的代码,就不再解析了。

我们跟着源码可以看到,使用 textRows 的场景在 getSystemVar 以及 Query 中,而使用 binaryRows 的场景在 statement 中,就是我们下一步需要解析的部分。

Statement.go

Prepared Statement ,即预处理语句,他有什么优势呢,为什么 MySQL 要加入它?

  • 执行性能更高:MySQL 会对 Prepared Statement 语句预先进行编译成模板,并将 占位符 替换 参数 的位置,这样如果频繁执行一条参数只有少量替换的语句时候,性能会得到大量提高。可能有同学会有疑问,为什么 MySQL 语句还需要编译?那么可以来参考下这篇 MySQL Prepare 原理

  • 传输协议更优:Prepare Statement 在传输时候使用的是 Binary Protocol ,比使用 Text Protocol 的查询具有 传输数据量更小无需转换数据格式 等优势,缓解了 CPU网络 的开销。

  • 安全性更好:由 MySQL Prepare 原理 我们可以知道,Perpare 编译之后会生成 语法树,在执行的时候才会将参数传进来,这样就避免了平常直接执行 SQL 语句 会发生的 SQL 注入 问题。

好了,先来看下 mysqlStmt 的数据结构:

type mysqlStmt struct {
    mc          *mysqlConn
    id          uint32
    paramCount  int
    columns     []mysqlField // cached from the first query (既然SQL已经预编译好了,返回的结果集列名已经是确定的,所以在收到 PREPARE_OK 之后解析数据后会缓存下来)
}

我们发现,它比 mysqlRows 多了两个成员变量:

  • idMySQL 预处理语句之后,会给该语句分配一个 id 并返回客户端,用于:

  • 客户端提交该 id 给服务器调用对应的预处理语句。

  • paramCount :参数数量,等于 占位符 的个数,用于:

  • 判断传入的参数个数是否与预编译语句中的占位符个数一致。

  • 判断返回的 PREPARE_OK 响应报文是否带有 参数列名 数据。

下面来看看如何创建并使用一个 Prepare Statement

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // 传入需要预编译的 SQL 语句
    // 检查连接是否可用...
  
    err = mc.writeCommandPacketStr(comStmtPrepare, query) // 将 SQL 发往数据库进行预编译
    if err != nil {
        return nil, err
    }

    stmt := &mysqlStmt{ // 预编译成功,先创建 stmt 对象
        mc: mc,
    }
  
    // Read Result
    columnCount, err := stmt.readPrepareResultPacket() // 从 stmt 的连接读取返回 响应报文
    if err == nil {
        if stmt.paramCount > 0 { // 如果预编译的 SQL 的有参数
            if err = mc.readUntilEOF(); err != nil { // 读取参数列名数据
                return nil, err
            }
        }
        
        if columnCount > 0 { // 返回执行结果的列表个数
            err = mc.readUntilEOF() // 读取执行结果的列名数据
        }
    }
  
    return stmt, err
}

因为是已经预编译好的语句,所以在执行的时候只需要将参数传进去就可以了。

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
    // 检查连接是否可用...
  
    err := stmt.writeExecutePacket(args)
    if err != nil {
        return nil, err
    }
  
    // 读取结果集的行、列数据
}

func(stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
    if len(args) != stmt.paramCount { // 判断传进来的参数和预编译好的SQL参数 个数是否一致
        return fmt.Errorf(
            "argument count mismatch (got: %d; has: %d)",
            len(args),
            stmt.paramCount,
        )
    }
  
    // 读取缓冲器中的数据,如果为空,则返回异常...
  
    // command [1 byte]
    data[4] = comStmtExecute

    // statement_id [4 bytes] 将预编译语句的 id 转换为 4字节的二进制数据
    data[5] = byte(stmt.id)
    data[6] = byte(stmt.id >> 8)
    data[7] = byte(stmt.id >> 16)
    data[8] = byte(stmt.id >> 24)

    // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
    data[9] = 0x00

    // iteration_count (uint32(1)) [4 bytes]
    data[10] = 0x01
    data[11] = 0x00
    data[12] = 0x00
    data[13] = 0x00
  
    // 将参数按照不同的类型转换为 binary protobuf 并 append 到 data 中...
  
    return mc.writePacket(data)
}

相信看到这里,已经能对看懂源码的 70% 了,剩余的代码都是和协议相关,就留待有兴趣的读者继续研究,这里就不再展开讲了。

Transaction.go

事务是 MySQL 中很重要的一部分,但是驱动的实现却很简单,因为一切的事务控制都已经交由 MySQL 去执行了,驱动所需要做的,只要发送一个 commit 或者 rollbackcommand packet 即可。

type mysqlTx struct {
    mc *mysqlConn
}

func (tx *mysqlTx) Commit() (err error) {
    if tx.mc == nil || tx.mc.netConn == nil {
        return ErrInvalidConn
    }
    err = tx.mc.exec("COMMIT")
    tx.mc = nil
    return
}

func (tx *mysqlTx) Rollback() (err error) {
    if tx.mc == nil || tx.mc.netConn == nil {
        return ErrInvalidConn
    }
    err = tx.mc.exec("ROLLBACK")
    tx.mc = nil
    return
}

总结

最后,其实 buffer 的实现对我来说印象是最深刻的,因为它是最简单而又是最有效的实现了一个消息缓冲器,它实现的巧妙让我决定把它放到第一节,而其他的几乎都和 MySQL 的协议相关,看这些源码让我对 MySQL 有了更多的认识。

好了,本篇字数比较多,也会有很多不足,希望大家能够给本篇博客多提点意见,让我可以改进的更好。如果还有机会,我会带来其他篇章的源码解析,敬请期待 :)

参考链接

相关文章

网友评论

      本文标题:go-sql-driver 源码解析

      本文链接:https://www.haomeiwen.com/subject/zlvyuxtx.html