golang源码分析:go-mysql(5)fake server

Golang
195
0
0
2024-01-19

如何定义一个 Fake server，让它接受客户端的请求并返回我们期望的结果？它本质上是一个 TCP 服务器，定义过程如下：

  // Listen on a local TCP port and accept a single client connection.
  l, err := net.Listen("tcp", "127.0.0.1:4000")
  if err != nil {
    log.Fatal(err)
  }
  defer l.Close()

  c, err := l.Accept()
  if err != nil {
    log.Fatal(err)
  }

  // server.NewConn runs the MySQL handshake, accepting user "root" with an empty
  // password, and answers commands through EmptyHandler.
  conn, err := server.NewConn(c, "root", "", server.EmptyHandler{})
  if err != nil {
    log.Fatal(err)
  }

  // Serve one command per iteration. HandleCommand closes the connection itself
  // when it returns an error (including the client going away), so stop looping
  // instead of killing the process.
  for {
    if err := conn.HandleCommand(); err != nil {
      log.Println(err)
      break
    }
  }

建立连接的相关代码位于：github.com/go-mysql-org/go-mysql@v1.7.0/server/conn.go（以下为节选）：

func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, error) {
  p := NewInMemoryProvider()
  p.AddUser(user, password)
  packetConn = packet.NewConn(conn)
  c := &Conn{
    Conn:               packetConn,
    serverConf:         defaultServer,
    credentialProvider: p,
    h:                  h,
    connectionID:       atomic.AddUint32(&baseConnID, 1),
    stmts:              make(map[uint32]*Stmt),
    salt:               RandomBuf(20),
  }
  if err := c.handshake(); err != nil {

其中的 InMemoryProvider 是一个内存中的凭证提供者，它内部用一个 sync.Map 保存用户名到密码的映射：

// InMemoryProvider implements an in-memory credential provider.
// server.NewConn creates one, adds the single user/password pair it was
// given, and installs it as the connection's CredentialProvider.
type InMemoryProvider struct {
  userPool sync.Map // username -> password
}

其中 packet.NewConn 负责把底层的 net.Conn 包装成按 MySQL 数据包读写的连接，并为读取配置了 64KB 的缓冲区，过程如下：

// NewConn wraps conn for MySQL packet I/O. Reads are served through a
// 64 KiB bufio.Reader layered on top of the returned Conn itself, and a
// 16 KiB copyNBuf buffer plus a buffer pool are allocated up front.
func NewConn(conn net.Conn) *Conn {
  c := &Conn{
    Conn:     conn,
    bufPool:  NewBufPool(),
    copyNBuf: make([]byte, 16*1024),
  }

  // The reader must be built after c exists, because it reads from c.
  c.br = bufio.NewReaderSize(c, 65536) // 64kb
  c.reader = c.br

  return c
}

server.NewConn 返回的 Conn 代表服务端的一个 MySQL 连接，它内嵌了 *packet.Conn：

// Conn is the server side of one MySQL client connection. It embeds
// *packet.Conn for packet-level I/O and adds the per-connection state that
// server.NewConn, the handshake, and dispatch read and write.
type Conn struct {
  *packet.Conn


  // Server configuration; server.NewConn sets it to defaultServer.
  serverConf     *Server
  // Session properties; presumably filled in from the client's handshake
  // response — TODO confirm against readHandshakeResponse.
  capability     uint32
  charset        uint8
  authPluginName string
  attributes     map[string]string
  // Assigned by server.NewConn via atomic.AddUint32(&baseConnID, 1).
  connectionID   uint32
  status         uint16
  warnings       uint16
  // server.NewConn fills this with RandomBuf(20).
  salt           []byte // should be 8 + 12 for auth-plugin-data-part-1 and auth-plugin-data-part-2


  // Authenticates the client; server.NewConn installs an InMemoryProvider.
  credentialProvider  CredentialProvider
  // NOTE(review): presumably the client's account and caching_sha2_password
  // auth state — confirm against the handshake code.
  user                string
  password            string
  cachingSha2FullAuth bool


  // Receives the commands decoded by dispatch.
  h Handler


  // Prepared statements keyed by ID. stmtID is the last ID handed out; dispatch
  // increments it on every COM_STMT_PREPARE.
  stmts  map[uint32]*Stmt
  stmtID uint32


  closed sync2.AtomicBool
}

接着进行握手：服务端先发送初始握手包，再读取并处理客户端返回的握手响应（节选）：

func (c *Conn) handshake() error {
  if err := c.writeInitialHandshake(); err != nil {
    return err
  }
  if err := c.readHandshakeResponse(); err != nil {

握手完成后，就进入请求处理与响应的流程，代码位于 github.com/go-mysql-org/go-mysql@v1.7.0/server/command.go。处理请求前需要先提供请求处理器（即 NewConn 的 Handler 参数）。库中默认实现了一个空处理器 EmptyHandler，以及一个空的复制请求处理器 EmptyReplicationHandler（节选）：

// EmptyHandler is the library's default, do-nothing handler.
type EmptyHandler struct{}

// UseDB accepts any database name without checking it.
func (EmptyHandler) UseDB(dbName string) error { return nil }

// EmptyReplicationHandler is the empty replication handler. It embeds
// EmptyHandler, so it inherits all of EmptyHandler's methods.
type EmptyReplicationHandler struct {
  EmptyHandler
}

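EmptyHandler 实现的是下面的 Handler 接口，共 7 个方法。前 6 个分别处理 COM_INIT_DB、COM_QUERY、COM_FIELD_LIST、COM_STMT_PREPARE、COM_STMT_EXECUTE、COM_STMT_CLOSE 命令，HandleOtherCommand 则兜底处理其余未被库直接支持的命令：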

// Handler receives the commands a client sends on a server Conn. dispatch
// calls the matching method, and WriteValue turns the returned value (or
// error) into the reply packet(s).
type Handler interface {
  // UseDB handles COM_INIT_DB; implementations can check whether dbName is valid, or do other work.
  UseDB(dbName string) error
  // HandleQuery handles COM_QUERY (SELECT, INSERT, UPDATE, etc.); query is the raw query text.
  // If the returned Result has a Resultset (SELECT, SHOW, etc.), that result set is sent
  // as the response; otherwise an OK packet built from the Result is sent.
  HandleQuery(query string) (*Result, error)
  // HandleFieldList handles COM_FIELD_LIST. dispatch splits the request at the first
  // 0x00 byte: the part before it is table, the rest is fieldWildcard.
  HandleFieldList(table string, fieldWildcard string) ([]*Field, error)
  // HandleStmtPrepare handles COM_STMT_PREPARE. params is the number of parameters and
  // columns the number of result columns of the statement; context is stored with the
  // statement and handed back to HandleStmtExecute and HandleStmtClose.
  HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error)
  // HandleStmtExecute handles COM_STMT_EXECUTE. context is the value returned by
  // HandleStmtPrepare, query is the prepared query, and args are the statement's parameters.
  HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error)
  // HandleStmtClose handles COM_STMT_CLOSE; context is the value returned by HandleStmtPrepare.
  // No response is sent for this command (only an error, if one is returned).
  HandleStmtClose(context interface{}) error
  // HandleOtherCommand handles any command the library does not handle itself, including
  // COM_SET_OPTION and, when the handler is not a ReplicationHandler, the replication commands.
  // The default implementation of this method returns an ER_UNKNOWN_ERROR.
  HandleOtherCommand(cmd byte, data []byte) error
}

复制请求处理器接口 ReplicationHandler 包含 3 个方法，分别处理注册从库（COM_REGISTER_SLAVE）、按位点 dump binlog（COM_BINLOG_DUMP）和按 GTID dump binlog（COM_BINLOG_DUMP_GTID）：

// ReplicationHandler is an optional extension of Handler for replication
// commands. dispatch type-asserts the connection's Handler to it, and falls
// back to Handler.HandleOtherCommand when the handler does not implement it.
type ReplicationHandler interface {
  // HandleRegisterSlave handles COM_REGISTER_SLAVE; data is the command data as passed by dispatch.
  HandleRegisterSlave(data []byte) error
  // HandleBinlogDump handles COM_BINLOG_DUMP with the position parsed from the request;
  // the events of the returned streamer are written back to the client.
  HandleBinlogDump(pos Position) (*replication.BinlogStreamer, error)
  // HandleBinlogDumpGTID handles COM_BINLOG_DUMP_GTID with the GTID set parsed from the
  // request; the events of the returned streamer are written back to the client.
  HandleBinlogDumpGTID(gtidSet *MysqlGTIDSet) (*replication.BinlogStreamer, error)
}

HandleCommand 每次处理一个请求，调用方在循环中不断调用它（见开头的示例）。这个请求处理循环是 server 的核心，上面的处理器方法就是在其中被调用的（节选）：

func (c *Conn) HandleCommand() error {
data, err := c.ReadPacket()
v := c.dispatch(data)
err = c.WriteValue(v)
if c.Conn != nil {
    c.ResetSequence()
if err != nil {
    c.Close()

每次调用时，它读取一个数据包并分发给处理器，再把处理结果写回客户端；出错时会关闭连接。读取数据包的代码位于 github.com/go-mysql-org/go-mysql@v1.7.0/packet/conn.go：

// ReadPacket reads the next packet. It is ReadPacketReuseMem called without
// a caller-supplied destination buffer to reuse.
func (c *Conn) ReadPacket() ([]byte, error) {
  var dst []byte // nil: no buffer to reuse
  return c.ReadPacketReuseMem(dst)
}

ReadPacketReuseMem 先把数据包读入一个临时缓冲区（由 utils.BytesBufferGet 获取）；如果调用方传入的 dst 非空，就把读到的数据追加到 dst 之后（节选）：

func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
        buf := utils.BytesBufferGet()
        if err := c.ReadPacketTo(buf); err != nil {
        if len(dst) > 0 {
    result = append(dst, readBytes...)

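dispatch 根据 MySQL 命令类型（cmd）进行分发，调用处理器中对应的方法。例如，COM_PING 直接返回 nil（最终回复 OK 包），COM_QUIT 则关闭连接且不回复。复制相关的命令只有在处理器实现了 ReplicationHandler 时才交给它处理，否则交给 HandleOtherCommand（节选）：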

func (c *Conn) dispatch(data []byte) interface{} {
  switch cmd {
  case COM_QUIT:
    c.Close()
    c.Conn = nil
    return noResponse{}
  case COM_QUERY:
    if r, err := c.h.HandleQuery(hack.String(data)); err != nil {
      return err
    } else {
      return r
    }
  case COM_PING:
    return nil
  case COM_INIT_DB:
    if err := c.h.UseDB(hack.String(data)); err != nil {
      return err
    } else {
      return nil
    }
  case COM_FIELD_LIST:
    index := bytes.IndexByte(data, 0x00)
    table := hack.String(data[0:index])
    wildcard := hack.String(data[index+1:])


    if fs, err := c.h.HandleFieldList(table, wildcard); err != nil {
      return err
    } else {
      return fs
    }
  case COM_STMT_PREPARE:
    c.stmtID++
    st := new(Stmt)
    st.ID = c.stmtID
    st.Query = hack.String(data)
    var err error
    if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
      return err
    } else {
      st.ResetParams()
      c.stmts[c.stmtID] = st
      return st
    }
  case COM_STMT_EXECUTE:
    if r, err := c.handleStmtExecute(data); err != nil {
      return err
    } else {
      return r
    }
  case COM_STMT_CLOSE:
    if err := c.handleStmtClose(data); err != nil {
      return err
    }
    return noResponse{}
  case COM_STMT_SEND_LONG_DATA:
    if err := c.handleStmtSendLongData(data); err != nil {
      return err
    }
    return noResponse{}
  case COM_STMT_RESET:
    if r, err := c.handleStmtReset(data); err != nil {
      return err
    } else {
      return r
    }
  case COM_SET_OPTION:
    if err := c.h.HandleOtherCommand(cmd, data); err != nil {
      return err
    }

    return eofResponse{}
  case COM_REGISTER_SLAVE:
    if h, ok := c.h.(ReplicationHandler); ok {
      return h.HandleRegisterSlave(data)
    } else {
      return c.h.HandleOtherCommand(cmd, data)
    }
  case COM_BINLOG_DUMP:
    if h, ok := c.h.(ReplicationHandler); ok {
      pos, err := parseBinlogDump(data)
      if err != nil {
        return err
      }
      if s, err := h.HandleBinlogDump(pos); err != nil {
        return err
      } else {
        return s
      }
    } else {
      return c.h.HandleOtherCommand(cmd, data)
    }
  case COM_BINLOG_DUMP_GTID:
    if h, ok := c.h.(ReplicationHandler); ok {
      gtidSet, err := parseBinlogDumpGTID(data)
      if err != nil {
        return err
      }
      if s, err := h.HandleBinlogDumpGTID(gtidSet); err != nil {
        return err
      } else {
        return s
      }
    } else {
      return c.h.HandleOtherCommand(cmd, data)
    }
  default:
    return c.h.HandleOtherCommand(cmd, data)

最后是结果的返回流程，代码位于 github.com/go-mysql-org/go-mysql@v1.7.0/server/resp.go。WriteValue 根据返回值的类型选择不同的序列化方式写回客户端，例如：nil 写 OK 包，error 写错误包，带 Resultset 的 *Result 写结果集，*Stmt 写预处理语句的响应。

  // WriteValue encodes the value produced by dispatch and writes it to the
  // client, choosing the encoding by the dynamic type of value:
  //   - noResponse writes nothing;
  //   - eofResponse writes an EOF packet;
  //   - an error writes an error packet;
  //   - untyped nil writes an OK packet;
  //   - a *Result writes its Resultset when it has one, otherwise an OK packet;
  //   - field lists, field values, binlog streamers and prepared statements
  //     use their dedicated writers.
  // Any other type is rejected with an "invalid response type" error.
  func (c *Conn) WriteValue(value interface{}) error {
  switch resp := value.(type) {
  case noResponse:
    return nil
  case eofResponse:
    return c.writeEOF()
  case error:
    return c.writeError(resp)
  case nil:
    // Only an untyped nil lands here; a typed nil *Result is handled below.
    return c.writeOK(nil)
  case *Result:
    if resp == nil || resp.Resultset == nil {
      return c.writeOK(resp)
    }
    return c.writeResultset(resp.Resultset)
  case []*Field:
    return c.writeFieldList(resp, nil)
  case []FieldValue:
    return c.writeFieldValues(resp)
  case *replication.BinlogStreamer:
    return c.writeBinlogEvents(resp)
  case *Stmt:
    return c.writePrepare(resp)
  }
  return fmt.Errorf("invalid response type %T", value)
}

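预处理语句的响应由 writePrepare 生成，代码位于 github.com/go-mysql-org/go-mysql@v1.7.0/server/stmt.go。它先写一个包含语句 ID、列数和参数个数的响应头包，再依次写出参数定义包和列定义包，每组后面各跟一个 EOF 包；数量为 0 的组整体省略：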

// writePrepare sends the response to a successful COM_STMT_PREPARE for s:
//  1. a header packet carrying the statement ID and the column/parameter counts;
//  2. one parameter-definition packet per parameter, then an EOF packet;
//  3. one column-definition packet per result column, then an EOF packet.
//
// Step 2 or 3 is skipped entirely (EOF included) when its count is zero.
func (c *Conn) writePrepare(s *Stmt) error {
  // The first 4 bytes are left for the packet header (presumably filled in by
  // WritePacket); the payload follows.
  data := make([]byte, 4, 128)
  data = append(data, 0)                                   // status: OK
  data = append(data, Uint32ToBytes(s.ID)...)              // statement ID
  data = append(data, Uint16ToBytes(uint16(s.Columns))...) // number of columns
  data = append(data, Uint16ToBytes(uint16(s.Params))...)  // number of params
  data = append(data, 0)                                   // reserved filler [00]
  data = append(data, 0, 0)                                // warning count
  if err := c.WritePacket(data); err != nil {
    return err
  }

  if s.Params > 0 {
    for remaining := s.Params; remaining > 0; remaining-- {
      // Reuse the buffer: keep the header slot, replace the payload.
      data = append(data[:4], paramFieldData...)
      if err := c.WritePacket(data); err != nil {
        return errors.Trace(err)
      }
    }
    if err := c.writeEOF(); err != nil {
      return err
    }
  }

  if s.Columns > 0 {
    for remaining := s.Columns; remaining > 0; remaining-- {
      data = append(data[:4], columnFieldData...)
      if err := c.WritePacket(data); err != nil {
        return errors.Trace(err)
      }
    }
    if err := c.writeEOF(); err != nil {
      return err
    }
  }

  return nil
}