美文网首页
golang grpc mtls=>tcp自动服务降级

golang grpc mtls=>tcp自动服务降级

作者: 万万没想到367 | 来源:发表于2022-07-19 16:10 被阅读0次

    整体思路

    借助 protoc 插件，根据 proto 文件生成支持自动降级（mTLS 调用失败时自动改用普通 TCP 连接重试）的 client 文件。

    具体实现

    generate.go

    // BaseGenerator holds the inputs shared by every Generator.
    type BaseGenerator struct {
        // Pkg holds the parsed description of each proto file (see parseFile).
        Pkg []*models.File
        // Gen is the protoc plugin; it is used to create output files and to
        // look up imported proto files by path.
        Gen *protogen.Plugin
    }
    
    // GeneratorFactory returns the generators to run over the parsed files.
    // Currently that is a single MTLSGenerator sharing files and gen.
    func GeneratorFactory(files []*models.File, gen *protogen.Plugin) []Generator {
        mtls := &MTLSGenerator{
            BaseGenerator: BaseGenerator{Pkg: files, Gen: gen},
        }
        return []Generator{mtls}
    }
    
    // Generator produces output files for the parsed proto files.
    type Generator interface {
        // Generate writes its output files and reports any failure as an error.
        Generate() error
    }
    
    // MTLSGenerator emits, for every service, a client wrapper that calls the
    // service over an mTLS connection first and retries over a plain TCP
    // connection when the mTLS call returns an error.
    type MTLSGenerator struct {
        BaseGenerator
    }
    
    // Generate writes one "<prefix>_mtls.pb.go" file for every parsed proto file
    // that declares at least one service.
    //
    // For each service Foo the generated file contains:
    //   - an unexported fooMTLSClient struct holding two FooClient values, one
    //     built on the mTLS connection and one on the plain TCP connection;
    //   - NewFooMTLSClient(mtls, tcp), which returns a plain FooClient when only
    //     one of the two connections is non-nil;
    //   - one wrapper per method that calls the mTLS client and, if that returns
    //     any error, repeats the call on the TCP client.
    //
    // FooClient / NewFooClient are referenced, not generated here; presumably
    // they come from protoc-gen-go-grpc output in the same package — verify.
    //
    // NOTE(review): the wrapper falls back on *every* error, including
    // application status codes (InvalidArgument, AlreadyExists, ...) and
    // deadline errors, so a non-idempotent RPC that the server already handled
    // over mTLS may be sent again over TCP. Consider falling back only on
    // transport-level failures (e.g. codes.Unavailable).
    //
    // NOTE(review): all wrappers use the unary signature; streaming methods
    // would not match the FooClient interface — TODO confirm services here are
    // unary-only.
    //
    // The returned error is currently always nil.
    func (m *MTLSGenerator) Generate() error {
        for _, p := range m.Pkg {
            if len(p.Services) == 0 {
                continue
            }
            g := m.Gen.NewGeneratedFile(getFilePath(p.FilenamePrefix), "")
            g.P("// Code generated by protoc-gen-mtls. DO NOT EDIT.")
            g.P()
            g.P("package ", p.GoPackageName)
            g.P()
            g.P("import (")
            g.P("  ", "context \"context\"")
            g.P()
            g.P("  ", "grpc \"google.golang.org/grpc\"")
            g.P()

            // Map alias (last proto package component) -> quoted Go import path
            // for every dependency that lives in a different Go package. A later
            // import with the same alias overwrites an earlier one, but aliases
            // are emitted in first-seen descriptor order so the generated import
            // block is deterministic (ranging over the map would not be).
            importPaths := map[string]string{}
            var aliasOrder []string
            for i, imps := 0, p.Imports; i < imps.Len(); i++ {
                imp := imps.Get(i)
                if imp.IsWeak {
                    continue
                }
                impFile, ok := m.Gen.FilesByPath[imp.Path()]
                if !ok {
                    continue
                }
                if string(impFile.GoImportPath) == p.GoImportPath {
                    // Don't generate imports or aliases for types in the same Go package.
                    continue
                }
                alias := string(imp.Name())
                if _, seen := importPaths[alias]; !seen {
                    aliasOrder = append(aliasOrder, alias)
                }
                importPaths[alias] = fmt.Sprintf(`"%s"`, string(impFile.GoImportPath))
            }
            for _, alias := range aliasOrder {
                // Only import packages the method signatures reference; an unused
                // import would make the generated file fail to compile.
                if p.NeedImport[alias] {
                    g.P("  ", alias, " ", importPaths[alias])
                }
            }
            g.P(")")
            g.P()

            for _, service := range p.Services {
                structName := getFirstLowServiceName(service.Name) + "MTLSClient"
                clientIface := service.Name + "Client"
                newClient := "New" + service.Name + "Client"

                g.P("type ", structName, " struct{")
                g.P("  mtlsClient ", clientIface)
                g.P("  tcpClient ", clientIface)
                g.P("}")
                g.P()
                g.P("func", " New", service.Name, "MTLSClient", "(mtls,tcp grpc.ClientConnInterface) ", clientIface, " {")
                g.P("  ", "if mtls == nil {")
                g.P("    ", "return ", newClient, "(tcp)")
                g.P("  ", "}")
                // With no TCP connection there is nothing to fall back to; without
                // this guard the fallback would invoke a client built on a nil
                // connection after the first failed mTLS call.
                g.P("  ", "if tcp == nil {")
                g.P("    ", "return ", newClient, "(mtls)")
                g.P("  ", "}")
                g.P("  ", "return &", structName, "{", getServiceName(service.Name, "mtls"), ",",
                    getServiceName(service.Name, "tcp"), "}")
                g.P("}")
                g.P()

                for _, method := range service.Methods {
                    reqType := getRequest(p.GoPackageName, method.Req)
                    respType := getRequest(p.GoPackageName, method.Resp)
                    g.P("func (c *", structName, ") ", method.Name, "(ctx context.Context, in *", reqType,
                        ", opts ...grpc.CallOption) (*", respType, ", error) {")
                    g.P("  ", "resp, err := c.mtlsClient.", method.Name, "(ctx, in, opts...)")
                    g.P("  ", "if err!=nil {")
                    g.P("    ", "return c.tcpClient.", method.Name, "(ctx, in, opts...)")
                    g.P("  ", "}")
                    g.P("  ", "return resp,err")
                    g.P("}")
                    g.P()
                }
            }
        }
        return nil
    }
    
    // getRequest turns a fully-qualified proto message name into the Go type
    // expression used in generated code.
    //
    // The message name is the text after the last '.', and its qualifier is the
    // dot-separated component just before it. If the qualifier equals
    // packageName the bare message name is returned; otherwise the result is
    // "qualifier.Message". A name without any '.' is returned unchanged.
    func getRequest(packageName, req string) string {
        lastDot := strings.LastIndex(req, ".")
        if lastDot < 0 {
            return req
        }
        msgName := req[lastDot+1:]
        qualifier := req[:lastDot]
        if cut := strings.LastIndex(qualifier, "."); cut >= 0 {
            qualifier = qualifier[cut+1:]
        }
        if qualifier == packageName {
            return msgName
        }
        return qualifier + "." + msgName
    }
    
    // getServiceName returns the Go expression that constructs the client for
    // service name on connection variable key, e.g. "NewGreeterClient(mtls)".
    func getServiceName(name, key string) string {
        return "New" + name + "Client(" + key + ")"
    }
    
    // getFirstLowServiceName returns name with its first character lower-cased,
    // e.g. "Greeter" -> "greeter"; it is used to derive the unexported wrapper
    // struct name. An empty name is returned unchanged, and the first character
    // is treated as a whole UTF-8 rune rather than a single byte so a multi-byte
    // letter is not split and corrupted.
    func getFirstLowServiceName(name string) string {
        if name == "" {
            return name
        }
        // width is the byte length of the first rune: the index at which the
        // second rune starts, or the whole string if there is only one rune.
        width := len(name)
        for i := range name {
            if i > 0 {
                width = i
                break
            }
        }
        return strings.ToLower(name[:width]) + name[width:]
    }
    
    // getFilePath returns the output file name for a proto file whose
    // generated-filename prefix is prefix, e.g. "api/hello" -> "api/hello_mtls.pb.go".
    func getFilePath(prefix string) string {
        const suffix = "_mtls.pb.go"
        return prefix + suffix
    }
    

    parse

    // getMessageMap indexes the top-level messages of every file known to the
    // plugin by fully-qualified proto name (e.g. "pkg.Message"). Nested messages
    // are not visited. If two files declare the same full name, the later file wins.
    func getMessageMap(gen *protogen.Plugin) map[string]*protogen.Message {
        byName := map[string]*protogen.Message{}
        for i := range gen.Files {
            for _, msg := range gen.Files[i].Messages {
                byName[string(msg.Desc.FullName())] = msg
            }
        }
        return byName
    }
    
    // parseFile converts the proto files handed to the plugin into the
    // models.File descriptions consumed by the generators.
    //
    // Only files protoc asked this plugin to generate (f.Generate) are returned.
    // Files present only as dependencies are skipped so that no *_mtls.pb.go
    // output is produced for them. For each service it records the methods and
    // the set of import aliases their request/response types need (see
    // parseMethod).
    func parseFile(gen *protogen.Plugin) ([]*models.File, error) {
        var files []*models.File
        messageMap := getMessageMap(gen)
        codeMap := make(map[string]string)
        for _, f := range gen.Files {
            if !f.Generate {
                continue
            }
            goPackageName := string(f.GoPackageName)
            file := &models.File{
                FilenamePrefix: f.GeneratedFilenamePrefix,
                GoImportPath:   string(f.GoImportPath),
                GoPackageName:  goPackageName,
                // GetPackage is nil-safe; dereferencing f.Proto.Package directly
                // panics for a .proto file without a package statement.
                ProtoPackageName: f.Proto.GetPackage(),
                Dependency:       f.Proto.Dependency,
                Imports:          f.Desc.Imports(),
            }
            needImport := map[string]bool{}
            services := make([]*models.Service, 0, len(f.Services))
            for _, service := range f.Services {
                methods, imports, err := parseMethod(goPackageName, service, codeMap, messageMap)
                if err != nil {
                    return nil, err
                }
                needImport = mergeMap(needImport, imports)
                services = append(services, &models.Service{Name: service.GoName, Methods: methods})
            }
            file.NeedImport = needImport
            file.Services = services
            files = append(files, file)
        }
        return files, nil
    }
    
    // mergeMap copies every entry of o into n, overwriting existing keys, and
    // returns n. n is modified in place and must be non-nil unless o is empty.
    func mergeMap(n, o map[string]bool) map[string]bool {
        for key := range o {
            n[key] = o[key]
        }
        return n
    }
    
    // parseMethod collects the methods of service and the import aliases their
    // request/response types require.
    //
    // name is the Go package name of the file being generated. A type needs an
    // import exactly when getRequest would qualify it: when the dot-separated
    // component just before the message name differs from name. The comparison is
    // exact, so a substring match (e.g. Go package "user" vs "userinfo.Profile")
    // does not suppress a needed import. Type names without any '.' are emitted
    // unqualified and never need an import.
    //
    // codeMap and messageMap are currently unused; they are kept so the
    // signature stays compatible with callers. The returned error is always nil.
    func parseMethod(name string, service *protogen.Service, codeMap map[string]string, messageMap map[string]*protogen.Message) ([]*models.Method, map[string]bool, error) {
        var methods []*models.Method
        needImport := map[string]bool{}
        markImport := func(fullName string) {
            if !strings.Contains(fullName, ".") {
                return
            }
            if pkg := getImportName(fullName); pkg != name {
                needImport[pkg] = true
            }
        }
        for _, method := range service.Methods {
            m := &models.Method{
                Name: method.GoName,
                Req:  string(method.Input.Desc.FullName()),
                Resp: string(method.Output.Desc.FullName()),
            }
            markImport(m.Req)
            markImport(m.Resp)
            methods = append(methods, m)
        }
        return methods, needImport, nil
    }
    
    // getImportName returns the dot-separated component just before the final
    // one in a fully-qualified proto name. For a top-level message this is the
    // last component of its package, e.g. "google.protobuf.Empty" -> "protobuf".
    // A name with no '.' has no qualifier, so "" is returned instead of
    // indexing out of range.
    func getImportName(i string) string {
        lastDot := strings.LastIndex(i, ".")
        if lastDot < 0 {
            return ""
        }
        qualifier := i[:lastDot]
        return qualifier[strings.LastIndex(qualifier, ".")+1:]
    }
    
    

    相关文章

      网友评论

          本文标题:golang grpc mtls=>tcp自动服务降级

          本文链接:https://www.haomeiwen.com/subject/kkqsirtx.html