整体思路
借助 protoc 插件，根据 proto 文件生成支持自动降级的 client 文件：调用优先走 mTLS 连接，失败时自动回退到普通 TCP 连接。
具体实现
generate.go
type (
	// BaseGenerator holds the state shared by every Generator: the parsed
	// proto files to generate from and the protoc plugin used to emit files.
	BaseGenerator struct {
		Pkg []*models.File
		Gen *protogen.Plugin
	}

	// Generator produces output files from the parsed proto models.
	Generator interface {
		Generate() error
	}

	// MTLSGenerator emits, for each proto file that declares services, a
	// client wrapper that calls over an mTLS connection first and falls back
	// to a plain TCP connection when that call fails.
	MTLSGenerator struct {
		BaseGenerator
	}
)

// Compile-time check that *MTLSGenerator satisfies Generator.
var _ Generator = (*MTLSGenerator)(nil)

// GeneratorFactory returns every Generator this plugin runs, all bound to the
// same parsed files and plugin. Currently that is only MTLSGenerator.
func GeneratorFactory(files []*models.File, gen *protogen.Plugin) []Generator {
	shared := BaseGenerator{Pkg: files, Gen: gen}
	return []Generator{&MTLSGenerator{BaseGenerator: shared}}
}
// Generate writes a <prefix>_mtls.pb.go file for every parsed proto file that
// declares at least one service. For each service the generated file contains:
//   - an unexported <service>MTLSClient struct holding two <Service>Client
//     values,
//   - a New<Service>MTLSClient constructor that returns the plain client built
//     from tcp when mtls is nil,
//   - one wrapper per method. Each wrapper calls the mTLS client first and, if
//     that call returns an error, sends the same request through the TCP
//     client.
//
// NOTE(review): the generated wrappers fall back on *any* error, including
// application-level status codes and context cancellation, so a
// non-idempotent RPC may be executed twice. Streaming methods are not
// special-cased and get the unary wrapper signature; confirm that the
// services are unary-only.
func (m *MTLSGenerator) Generate() error {
	for _, p := range m.Pkg {
		if len(p.Services) == 0 {
			continue
		}
		g := m.Gen.NewGeneratedFile(getFilePath(p.FilenamePrefix), protogen.GoImportPath(p.GoImportPath))
		g.P("// Code generated by protoc-gen-mtls. DO NOT EDIT.")
		g.P()
		g.P("package ", p.GoPackageName)
		g.P()
		g.P("import (")
		g.P(" ", "context \"context\"")
		g.P()
		g.P(" ", "grpc \"google.golang.org/grpc\"")
		g.P()

		// Map each directly imported proto file that lives in another Go
		// package. The key is its proto package name, which is the qualifier
		// getRequest puts on message types. The value is its quoted Go import
		// path. importOrder records first-seen order so the output is
		// reproducible; ranging over the map would emit the imports in a
		// different order on every run.
		needImport := map[string]string{}
		var importOrder []string
		for i := 0; i < p.Imports.Len(); i++ {
			imp := p.Imports.Get(i)
			if imp.IsWeak {
				continue
			}
			impFile, ok := m.Gen.FilesByPath[imp.Path()]
			if !ok {
				continue
			}
			if string(impFile.GoImportPath) == p.GoImportPath {
				// Don't generate imports or aliases for types in the same Go package.
				continue
			}
			alias := string(imp.Name())
			if _, seen := needImport[alias]; !seen {
				importOrder = append(importOrder, alias)
			}
			needImport[alias] = fmt.Sprintf("%q", string(impFile.GoImportPath))
		}
		// Emit only the imports that some request/response type actually
		// needs; an unused import would make the generated file fail to compile.
		for _, alias := range importOrder {
			if p.NeedImport[alias] {
				g.P(" ", alias, " ", needImport[alias])
			}
		}
		g.P(")")
		g.P()

		for _, service := range p.Services {
			impl := getFirstLowServiceName(service.Name) + "MTLSClient"
			g.P("type ", impl, " struct{")
			g.P(" mtlsClient ", service.Name, "Client")
			g.P(" tcpClient ", service.Name, "Client")
			g.P("}")
			g.P()
			g.P("func", " New", service.Name, "MTLSClient", "(mtls,tcp grpc.ClientConnInterface) ", service.Name, "Client", " {")
			// Without an mTLS connection there is nothing to fall back from,
			// so the constructor hands back the ordinary client built on tcp.
			g.P(" ", "if mtls == nil {")
			g.P(" ", "return New", service.Name, "Client(tcp)")
			g.P(" ", "}")
			g.P(" ", "return &", impl, "{", getServiceName(service.Name, "mtls"), ",",
				getServiceName(service.Name, "tcp"), "}")
			g.P("}")
			g.P()
			for _, method := range service.Methods {
				g.P("func (c *", impl, ") ", method.Name, "(ctx context.Context, in *", getRequest(p.GoPackageName, method.Req), ", opts ...grpc.CallOption) (*",
					getRequest(p.GoPackageName, method.Resp), ", error) {")
				g.P(" ", "resp, err := c.mtlsClient.", method.Name, "(ctx, in, opts...)")
				g.P(" ", "if err!=nil {")
				g.P(" ", "return c.tcpClient.", method.Name, "(ctx, in, opts...)")
				g.P(" ", "}")
				g.P(" ", "return resp,err")
				g.P("}")
			}
		}
	}
	return nil
}
// getRequest renders the fully-qualified proto message name req (for example
// "foo.pb.HelloRequest") as a Go type reference for a file in Go package
// packageName.
//
// The last dot-separated segment is the type name. The segment before it is
// treated as the package qualifier: it is dropped when it equals packageName,
// and otherwise kept as "<qualifier>.<Type>". A name without a dot is
// returned unchanged.
//
// NOTE(review): a nested message such as "pkg.Outer.Inner" renders as
// "Outer.Inner", which is not the Go name protoc-gen-go generates
// (Outer_Inner). Confirm that nested request/response types are not used.
func getRequest(packageName, req string) string {
	lastDot := strings.LastIndexByte(req, '.')
	if lastDot < 0 {
		return req
	}
	typeName := req[lastDot+1:]
	qualifier := req[:lastDot]
	if prevDot := strings.LastIndexByte(qualifier, '.'); prevDot >= 0 {
		qualifier = qualifier[prevDot+1:]
	}
	if qualifier == packageName {
		return typeName
	}
	return qualifier + "." + typeName
}
// getServiceName returns the Go source expression "New<name>Client(<key>)",
// i.e. a call to the protoc-gen-go-grpc client constructor for service name
// using the connection variable key.
func getServiceName(name, key string) string {
	// All operands are strings, so fmt.Sprint inserts no separators.
	return fmt.Sprint("New", name, "Client(", key, ")")
}
// getFirstLowServiceName returns name with its first character lower-cased
// ("Greeter" -> "greeter"). It is used to build the name of the unexported
// generated client struct.
//
// An empty name yields "". The first rune, not the first byte, is lowered,
// so a multi-byte leading character is not split.
func getFirstLowServiceName(name string) string {
	first, rest := name, ""
	// Ranging over a string yields rune start offsets; the second offset is
	// where the first rune ends.
	for i := range name {
		if i > 0 {
			first, rest = name[:i], name[i:]
			break
		}
	}
	return strings.ToLower(first) + rest
}
// getFilePath returns the output file name for a proto file whose generated
// file-name prefix is prefix, by appending "_mtls.pb.go".
func getFilePath(prefix string) string {
	// Both operands are strings, so fmt.Sprint inserts no separator.
	return fmt.Sprint(prefix, "_mtls.pb.go")
}
parse
// getMessageMap indexes the top-level messages of every file known to the
// plugin by their fully-qualified proto name (for example
// "pkg.HelloRequest"). Nested messages are not included. If two messages
// share a full name, the one seen later wins.
func getMessageMap(gen *protogen.Plugin) map[string]*protogen.Message {
	byFullName := map[string]*protogen.Message{}
	for _, file := range gen.Files {
		for _, msg := range file.Messages {
			byFullName[string(msg.Desc.FullName())] = msg
		}
	}
	return byFullName
}
// parseFile converts the protoc plugin request into the generator's models.
// It produces one models.File for each proto file that protoc asked this
// plugin to generate (f.Generate). Each models.File carries:
//   - the file's Go and proto package information,
//   - its direct imports,
//   - its services and their methods,
//   - NeedImport: the set of proto package names referenced by those
//     methods' request/response types.
//
// The first error reported by parseMethod is returned, wrapped with the
// service and file it came from.
func parseFile(gen *protogen.Plugin) ([]*models.File, error) {
	var files []*models.File
	messageMap := getMessageMap(gen)
	codeMap := make(map[string]string)
	for _, f := range gen.Files {
		// The remaining entries in gen.Files are dependencies loaded only for
		// type resolution; emitting output for them would write files the
		// user never asked for.
		if !f.Generate {
			continue
		}
		file := &models.File{
			FilenamePrefix: f.GeneratedFilenamePrefix,
			GoImportPath:   string(f.GoImportPath),
			GoPackageName:  string(f.GoPackageName),
			// GetPackage returns "" for a file without a package statement,
			// where dereferencing f.Proto.Package would panic.
			ProtoPackageName: f.Proto.GetPackage(),
			Dependency:       f.Proto.Dependency,
			Imports:          f.Desc.Imports(),
		}
		needImport := map[string]bool{}
		var services []*models.Service
		for _, service := range f.Services {
			methods, refs, err := parseMethod(string(f.GoPackageName), service, codeMap, messageMap)
			if err != nil {
				return nil, fmt.Errorf("parsing service %s in %s: %w", service.GoName, f.Desc.Path(), err)
			}
			needImport = mergeMap(needImport, refs)
			services = append(services, &models.Service{Name: service.GoName, Methods: methods})
		}
		file.NeedImport = needImport
		file.Services = services
		files = append(files, file)
	}
	return files, nil
}
// mergeMap copies every entry of o into n, overwriting keys n already has,
// and returns n itself (not a copy). n must be non-nil whenever o is
// non-empty, because writing to a nil map panics.
func mergeMap(n, o map[string]bool) map[string]bool {
	for key := range o {
		n[key] = o[key]
	}
	return n
}
// parseMethod converts the methods of one protoc service into models.Method
// values. Each value holds the Go method name and the fully-qualified proto
// names of the request and response messages.
//
// name is the Go package name of the file being generated. The returned set
// holds the proto package qualifier of every request/response type that
// getRequest will render with a qualifier: the segment before the message
// name, whenever it differs from name. Generate uses this set to decide which
// imports to emit. codeMap and messageMap are currently unused. The returned
// error is always nil.
//
// NOTE(review): for a nested message type the "qualifier" is the enclosing
// message name, mirroring getRequest; confirm that nested request/response
// types are not used.
func parseMethod(name string, service *protogen.Service, codeMap map[string]string, messageMap map[string]*protogen.Message) ([]*models.Method, map[string]bool, error) {
	var methods []*models.Method
	needImport := map[string]bool{}
	for _, method := range service.Methods {
		in := method.Input.Desc.FullName()
		out := method.Output.Desc.FullName()
		methods = append(methods, &models.Method{
			Name: method.GoName,
			Req:  string(in),
			Resp: string(out),
		})
		// Compare the qualifier segment exactly, as getRequest does. A
		// substring test would, e.g., treat "pbv2.Req" as local to package
		// "pb" and omit an import the generated code needs. Parent().Name()
		// is "" for a message with no package, which needs no import.
		if pkg := string(in.Parent().Name()); pkg != "" && pkg != name {
			needImport[pkg] = true
		}
		if pkg := string(out.Parent().Name()); pkg != "" && pkg != name {
			needImport[pkg] = true
		}
	}
	return methods, needImport, nil
}
// getImportName returns the segment just before the last dot of the
// fully-qualified proto name i. For "foo.pb.HelloRequest" that is "pb", the
// proto package name used as the import alias.
//
// A name with no dot (a message declared without a package) has no package
// segment, so "" is returned; the previous indexing panicked on such input.
func getImportName(i string) string {
	lastDot := strings.LastIndexByte(i, '.')
	if lastDot < 0 {
		return ""
	}
	pkg := i[:lastDot]
	if prevDot := strings.LastIndexByte(pkg, '.'); prevDot >= 0 {
		pkg = pkg[prevDot+1:]
	}
	return pkg
}
网友评论