美文网首页
go语言实现内网穿透(udp版)

go语言实现内网穿透(udp版)

作者: 今天i你好吗 | 来源:发表于2024-10-23 22:58 被阅读0次

    相关内容

    node.js实现内网穿透: https://www.jianshu.com/p/d2d4f8bff599
    kotlin实现内网穿透: https://www.jianshu.com/p/c8dc095c758e

    最大设计连接数: 65535

    之前写了一个 UDP 转 TCP 再转 UDP 的工具, 打算把它和 TCP 内网穿透结合使用来实现 UDP 内网穿透. 但在实际使用中发现网速较慢, 初步判断是运营商网络的问题 (用 HTTP 下载也一样: 单线程下载只能达到 1MB/s 以内, 3 个线程就可以达到 10MB/s; 上传没有问题). 这个问题无法解决, 只好再写一个 UDP 版. 本来想用 UDP 打洞实现, 但其中一个网络不支持, 只好改用服务器转发. 目前仍然存在一个小问题, 暂时不打算解决, 文末会讲到.

    实现代码

    服务端:

    package main
    
    import (
        "crypto/rsa"
        "crypto/sha1"
        "crypto/x509"
        "encoding/pem"
        "errors"
        "fmt"
        "log"
        "net"
        "os"
        "strings"
        "sync"
        "time"
    )
    
    // Package-level configuration and shared state for the NAT dispatcher.
    var (
        // privateKeyStr is the PKCS#8 PEM RSA private key that startServer loads
        // to decrypt the heartbeat/auth payload (RSA-OAEP with SHA-1).
        // NOTE(review): a private key hard-coded in published source gives no
        // secrecy; anyone with this file can decrypt the payload. Consider
        // loading it from a file or the environment instead.
        privateKeyStr = "-----BEGIN PRIVATE KEY-----\n" +
            "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAJYBf37uy0sVXyxb\n" +
            "bMDjvQzxv/ke3UWCSkYhUd8e+MjGHeT8A9V9aVemg3qUogND/Pgtlz6bTd9p5H+q\n" +
            "OCXnrZbSwdAN7O/9x3zhRzaEOH+9OJ8vpF80DuatbdqqJXFPiDO+nfbufNyT8n+b\n" +
            "N9UXISAVv7Nay+vTEySD401czDDzAgMBAAECgYAE9u26TLrrtDxfInN5+s+R8xpQ\n" +
            "a2YVW9eLdKTaBpNjSbNJldGmqiznWrp1PyARjZl8uT2NM+Si5UVLuF19W6qSC1Tw\n" +
            "bn2brBSsFsufKQff0XXRsD8OqKT5h5/PlRATYgisZonD/v0SPfDHROcNCFiOelEv\n" +
            "kKZAJgJS3vghZx5jCQJBAMTb4esPOApK+6wXlfVShvRiac9UWW9KBYLwpqya5Cjd\n" +
            "X/QUGVwywMFkNgRnBz7yWdJd1zuXfuI77N87BXYOE68CQQDDEjfaKCyZJ73RfNJo\n" +
            "ED5DRfZXm86RPBDlZznJ4shjNVbk1sGClYAk0WuAHeIZJVpm1HpME6NSCfumqeOq\n" +
            "z1P9AkAQu+xJcgK+hT89ksexkfFc5ty9vhrYJf+v8MsKUyRgAOl+MxMwzjOqfN1G\n" +
            "pIduJ2XRRx7btvYXPybUlwzQy0OLAkBIexlznt/LXH/kOcv4TKjF2FYLAWKEhlwE\n" +
            "0REg2Xn5mtUZnE40lhYSGBoodXIQQ9fOQ37Zi6ZwkjMGHzPvwK+FAkAe9gHRMI2u\n" +
            "lzwVp/AQntBMXmw92IULIlfRmfV1jDBYuT0JHUUGZqfCrz+iDW8Ot24QBzLbwKxJ\n" +
            "JRrmXbUxObmC\n" +
            "-----END PRIVATE KEY-----"
        // natDispatcherAddr is the UDP address the dispatcher listens on;
        // replaced by the single command-line argument (main) or by Start.
        natDispatcherAddr = ":8989"
        // natPasswords lists the passwords handleNatUAuth accepts.
        natPasswords      = []string{"yzh"}
        // timeOut is the idle read deadline of an opened public port (renewed on
        // every heartbeat) and the idle age after which a public-side client
        // entry is evicted.
        timeOut           = time.Second * 60 * 2
        // serverMap holds the opened public ports keyed by port string; it is
        // accessed under mapMutex.
        serverMap         = map[string]*Server{}
        mapMutex          sync.Mutex
        // clientIndex is the counter accept uses to hand out client indexes.
        clientIndex       = 0
        // maxLength records the largest datagram seen, for the stats log.
        // NOTE(review): written from several goroutines without synchronization.
        maxLength         = 0
    )
    
    // main runs the dispatcher. A single optional command-line argument
    // replaces the default listen address.
    func main() {
        args := os.Args[1:]
        log.Println("入参: " + strings.Join(args, " "))
        if len(args) == 1 {
            natDispatcherAddr = args[0]
        }
        log.SetFlags(log.LstdFlags | log.Lshortfile)
        startServer()
    }
    
    // Start is the embeddable entry point: it installs the listen address and
    // the accepted passwords, then runs the dispatcher loop (blocking).
    func Start(natDispatcherAddrStr string, natPasswordsArr []string) {
        natDispatcherAddr, natPasswords = natDispatcherAddrStr, natPasswordsArr
        startServer()
    }
    
    // startServer runs the NAT dispatcher. It logs forwarding statistics once a
    // minute, loads the RSA private key, and then serves the dispatcher UDP
    // socket until the socket is closed. Every datagram starts with a command
    // byte:
    //   1 - heartbeat/auth: the rest is the RSA-OAEP ciphertext checked by
    //       handleNatUAuth; the reply is its single status byte.
    //   0 - data from the NAT client laid out as [0, idxHi, idxLo, payload...];
    //       the payload goes to the public-side peer with that client index.
    // Other commands, and data frames too short to carry an index, are dropped.
    func startServer() {
        log.Println("NatU分发服务地址: " + natDispatcherAddr)
        go func() {
            for {
                time.Sleep(time.Second * 60)
                println()
                // serverMap is mutated by getServer and the accept goroutines,
                // so even its length must be read while holding mapMutex.
                mapMutex.Lock()
                log.Println("natU转发Size:", len(serverMap), ",maxLength:", maxLength)
                for _, value := range serverMap {
                    log.Println("natU转发中:", value.toString())
                }
                mapMutex.Unlock()
            }
        }()
        privateKey, err := loadPrivateKey(privateKeyStr)
        if err != nil {
            log.Panicln(err)
        }

        listenerAddr, err := net.ResolveUDPAddr("udp", natDispatcherAddr)
        if err != nil {
            log.Println(err)
            return
        }
        // Pick an explicit address family when the host is a literal IP, so
        // e.g. "0.0.0.0:8989" binds IPv4 only.
        network := "udp"
        if listenerAddr.IP.To4() != nil {
            network = "udp4"
        } else if listenerAddr.IP.To16() != nil {
            network = "udp6"
        }
        listenerConn, err := net.ListenUDP(network, listenerAddr)
        if err != nil {
            log.Println(err)
            return
        }
        defer listenerConn.Close()
        buffer := make([]byte, 1024*64)
        for {
            n, clientAddr, err := listenerConn.ReadFromUDP(buffer)
            if err != nil {
                log.Println(err)
                if errors.Is(err, net.ErrClosed) {
                    break
                }
                // Non-fatal read error: back off instead of spinning.
                time.Sleep(time.Second)
                continue
            }
            if n > maxLength {
                maxLength = n
            }
            if n < 1 {
                log.Println("异常消息:", clientAddr)
                continue
            }
            // Copy the datagram out of the reusable read buffer.
            data := make([]byte, n)
            copy(data, buffer)

            switch cmd := data[0]; cmd {
            case 1:
                // Heartbeat / auth: answer with a single status byte.
                status := handleNatUAuth(data[1:], privateKey, clientAddr, listenerConn)
                listenerConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
                listenerConn.WriteToUDP([]byte{status}, clientAddr)
            case 0:
                // Reject frames too short to hold the two index bytes; indexing
                // data[1]/data[2] would otherwise panic and stop the dispatcher.
                if n < 3 {
                    log.Println("异常消息:", clientAddr)
                    continue
                }
                cIndex := int(data[1])<<8 | int(data[2])
                handleRealData(cIndex, data[3:])
            }
        }
    }
    
    func handleRealData(cIndex int, realData []byte) {
        servers := []*Server{}
        mapMutex.Lock()
        for _, server := range serverMap {
            servers = append(servers, server)
        }
        mapMutex.Unlock()
        for _, server := range servers {
            server.clientMutex.Lock()
            for _, value := range server.clientMap {
                if value.index == cIndex {
                    // log.Println("转成功:", cIndex, value.clientAddr)
                    value.openConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
                    value.openConn.WriteToUDP(realData, value.clientAddr)
                    value.lastLime = time.Now()
                    break
                }
            }
            server.clientMutex.Unlock()
        }
    }
    
    // handleNatUAuth authenticates a heartbeat from a NAT client and opens (or
    // re-validates) the public ports it requests. cmdData is RSA-OAEP (SHA-1)
    // ciphertext of a "-"-separated string in one of two layouts:
    //   version 1: "<openAddr>-<password>"          (only the first field is opened)
    //   version 2: "<password>-<openAddr>-<openAddr>..."
    // When the first field is an accepted password, version 2 wins.
    // The returned status byte is sent back to the client: 1 ok, 3 decryption
    // failed, 4 fewer than two fields, 5 wrong password, 6 getServer rejected
    // one of the addresses.
    func handleNatUAuth(cmdData []byte, privateKey *rsa.PrivateKey, clientAddr *net.UDPAddr, listenerConn *net.UDPConn) byte {
        plain, err := rsa.DecryptOAEP(sha1.New(), nil, privateKey, cmdData, nil)
        if err != nil {
            log.Println("DecryptPKCS1v15 err", err)
            return 3
        }
        fields := strings.Split(string(plain), "-")
        if len(fields) < 2 {
            log.Println("infos error", fields)
            return 4
        }
        isPassword := func(candidate string) bool {
            for _, pwd := range natPasswords {
                if pwd == candidate {
                    return true
                }
            }
            return false
        }
        legacyMatch := isPassword(fields[1])
        currentMatch := isPassword(fields[0])
        if !legacyMatch && !currentMatch {
            log.Println("密码错误", fields)
            return 5
        }
        version, openAddrs := 1, fields[:1]
        if currentMatch {
            version, openAddrs = 2, fields[1:]
        }
        joined := strings.Join(openAddrs, "-")
        for position, address := range openAddrs {
            if getServer(address, position, version, clientAddr, joined, listenerConn) == nil {
                log.Println("getServer nil")
                return 6
            }
        }
        return 1
    }
    
    // getServer returns the Server exposing address on the public side. If no
    // server owns that port yet, it opens one and starts its accept goroutine.
    // An existing server is reused only when its address, openAddrsStr and
    // version all match; otherwise nil is returned. On success the server's
    // reply address becomes clientAddr and its read deadline is pushed timeOut
    // ahead, so the port stays open for as long as heartbeats keep arriving.
    func getServer(address string, index int, version int, clientAddr *net.UDPAddr, openAddrsStr string, listenerConn *net.UDPConn) *Server {
        colon := strings.LastIndex(address, ":")
        if colon < 0 {
            log.Println("address error", address)
            return nil
        }
        port := address[colon+1:]

        mapMutex.Lock()
        defer mapMutex.Unlock()

        server := serverMap[port]
        if server == nil {
            udpAddr, err := net.ResolveUDPAddr("udp", address)
            if err != nil {
                log.Println(err)
                return nil
            }
            var family string
            switch {
            case udpAddr.IP.To4() != nil:
                family = "udp4"
            case udpAddr.IP.To16() != nil:
                family = "udp6"
            default:
                family = "udp"
            }
            conn, err := net.ListenUDP(family, udpAddr)
            if err != nil {
                log.Println(err)
                return nil
            }
            log.Println("开放地址:", address, "对方index:", index, "版本:", version)
            server = &Server{
                address:      address,
                index:        index,
                version:      version,
                createTime:   time.Now(),
                openConn:     conn,
                openAddrsStr: openAddrsStr,
                clientMap:    map[string]*Client{},
            }
            serverMap[port] = server
            go server.accept(port, listenerConn)
        } else {
            switch {
            case server.address != address:
                log.Println("端口重复", server.address, address)
                return nil
            case server.openAddrsStr != openAddrsStr:
                log.Println("openAddrsStr不同", server.openAddrsStr, openAddrsStr)
                return nil
            case server.version != version:
                log.Println("版本不同", server.version, version)
                return nil
            }
        }
        server.clientAddr = clientAddr
        server.openConn.SetReadDeadline(time.Now().Add(timeOut))
        return server
    }
    
    type (
        // Server is one public-side port opened at a NAT client's request.
        // Datagrams arriving on it are relayed to that client through the
        // dispatcher socket.
        Server struct {
            openAddrsStr string             // every address the client requested, joined by "-"; must match on reuse
            index        int                // position of address in the client's list; sent as command byte 10+index
            version      int                // auth payload layout (1 or 2)
            createTime   time.Time          // when the port was opened
            address      string             // public listen address as requested
            openConn     *net.UDPConn       // socket bound to address
            clientAddr   *net.UDPAddr       // NAT client's dispatcher-side address, refreshed by each heartbeat
            clientMap    map[string]*Client // public-side peers keyed by their address string
            clientMutex  sync.Mutex         // guards clientMap
        }

        // Client is one public-side peer talking to a Server's port.
        Client struct {
            index      int          // id sent on the wire in two bytes to route replies back
            lastLime   time.Time    // last activity (datagram from or to this peer)
            clientAddr *net.UDPAddr // the peer's address
            openConn   *net.UDPConn // the Server socket used to reply to the peer
        }
    )
    
    // accept serves the server's public port until a read fails. A failed read
    // includes hitting the read deadline that getServer extends on every
    // heartbeat. Each datagram from a public-side peer is tagged with that
    // peer's client index and relayed to the NAT client through listenerConn as
    //   [10+server.index, index>>8, index&0xFF, payload...].
    // When the loop ends the port is closed and removed from serverMap.
    func (server *Server) accept(port string, listenerConn *net.UDPConn) {
        buffer := make([]byte, 1024*64)
        for {
            n, clientAddr, err := server.openConn.ReadFromUDP(buffer)
            if err != nil {
                log.Println(err)
                break
            }
            if n > maxLength {
                maxLength = n
            }
            if n < 1 {
                log.Println("空消息:", clientAddr, port)
                continue
            }
            data := make([]byte, n)
            copy(data, buffer)

            server.clientMutex.Lock()
            client := server.clientMap[clientAddr.String()]
            if client == nil {
                client = server.addClient(clientAddr)
            }
            if client == nil {
                server.clientMutex.Unlock()
                log.Println("no free client index, dropping datagram:", clientAddr, port)
                continue
            }
            client.lastLime = time.Now()
            index := client.index
            server.clientMutex.Unlock()

            header := []byte{byte(10 + server.index), byte(index >> 8), byte(index)}
            listenerConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
            listenerConn.WriteToUDP(append(header, data...), server.clientAddr)
        }
        server.openConn.Close()
        mapMutex.Lock()
        delete(serverMap, port)
        mapMutex.Unlock()
        log.Println("释放端口:", port)
    }

    // addClient registers clientAddr under a fresh client index and returns the
    // new Client, or nil if every index is in use. While scanning it evicts
    // clients idle for at least timeOut. The caller must hold server.clientMutex.
    //
    // The index travels in two bytes, so the shared counter wraps at 1<<16.
    // Letting it grow past 65535 made the wire value and the stored value
    // disagree, and replies could no longer be routed.
    func (server *Server) addClient(clientAddr *net.UDPAddr) *Client {
        const indexSpace = 1 << 16
        for attempt := 0; attempt < indexSpace; attempt++ {
            // clientIndex is shared by all accept goroutines, each of which holds
            // a different clientMutex, so it is guarded by mapMutex. Nesting
            // mapMutex inside clientMutex is safe: no code path in this file takes
            // a clientMutex while holding mapMutex.
            mapMutex.Lock()
            clientIndex = (clientIndex + 1) % indexSpace
            index := clientIndex
            mapMutex.Unlock()

            inUse := false
            for key, value := range server.clientMap {
                // Evict stale entries first so an expired client does not keep
                // its index reserved.
                if time.Since(value.lastLime) >= timeOut {
                    delete(server.clientMap, key)
                    continue
                }
                if value.index == index {
                    inUse = true
                    break
                }
            }
            if inUse {
                log.Println("index无效:", index)
                continue
            }
            client := &Client{index: index, clientAddr: clientAddr, openConn: server.openConn}
            server.clientMap[clientAddr.String()] = client
            return client
        }
        return nil
    }
    
    // toString renders one status line for the periodic stats log:
    // "<address>=><clientAddr>, index: N, version: N, <D>天<H>时<M>分<S>秒",
    // where the trailing part is how long the port has been open.
    func (server *Server) toString() string {
        secs := int64(time.Since(server.createTime) / time.Second)
        mins, hours := secs/60, secs/3600
        uptime := fmt.Sprintf("%d天%d时%d分%d秒", hours/24, hours%24, mins%60, secs%60)
        return fmt.Sprintf("%s=>%s, index: %d, version: %d, %s",
            server.address, server.clientAddr, server.index, server.version, uptime)
    }
    
    // loadPrivateKey decodes the first PEM block of privateKeyStr and parses it
    // as a PKCS#8 RSA private key. It fails if no PEM block is found, the DER is
    // not valid PKCS#8, or the key is not RSA.
    func loadPrivateKey(privateKeyStr string) (privateKey *rsa.PrivateKey, err error) {
        block, _ := pem.Decode([]byte(privateKeyStr))
        if block == nil {
            return nil, fmt.Errorf("解码私钥失败")
        }
        parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
        if err != nil {
            return nil, err
        }
        if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
            return rsaKey, nil
        }
        return nil, fmt.Errorf("非法私钥文件")
    }
    
    

    客户端:

    package main
    
    import (
        "crypto/rand"
        "crypto/rsa"
        "crypto/sha1"
        "crypto/x509"
        "encoding/pem"
        "errors"
        "fmt"
        "log"
        "net"
        "os"
        "strings"
        "sync"
        "time"
    )
    
    // Package-level configuration and shared state for the NAT client.
    var (
        // publicKeyStr is the PKIX PEM RSA public key used to encrypt the
        // heartbeat/auth payload (RSA-OAEP with SHA-1); presumably the public
        // half of the dispatcher's private key.
        publicKeyStr = "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCWAX9+7stLFV8sW2zA470M8b/5\nHt1FgkpGIVHfHvjIxh3k/APVfWlXpoN6lKIDQ/z4LZc+m03faeR/qjgl562W0sHQ\nDezv/cd84Uc2hDh/vTifL6RfNA7mrW3aqiVxT4gzvp327nzck/J/mzfVFyEgFb+z\nWsvr0xMkg+NNXMww8wIDAQAB\n-----END PUBLIC KEY-----\n"
    
        // natDispatcherAddr is the dispatcher's UDP address; replaced by the
        // single command-line argument (main) or by Start.
        natDispatcherAddr = "127.0.0.1:8989"
        // natPassword is sent as the first "-"-separated field of the payload.
        natPassword       = "yzh"
        // rateInterval is the pause between heartbeats.
        rateInterval      = time.Second * 15
        // timeOut is the read deadline of the dispatcher socket and of each
        // forwarding socket to a local service.
        timeOut           = time.Second * 60 * 2
        // natMapArr entries are "<address to open on the dispatcher>-<local
        // service address>". The list position is the index the dispatcher uses
        // to route data back (command byte 10+position).
        natMapArr         = []string{
            ":1701-192.168.3.25:1701",
            ":11771-127.0.0.1:1771",
            ":11772-127.0.0.1:1772",
            ":11773-127.0.0.1:1773",
        }
        // errMap maps an error status byte (2-9) from the dispatcher to the
        // message creatClient panics with; empty entries fall back to
        // "config error". Codes 3-6 mirror the failures of the server's
        // handleNatUAuth. NOTE(review): code 2 is not produced by the server
        // code in this article; presumably kept for an older server version.
        errMap = map[int]string{
            2: "config error or port used",
            3: "DecryptPKCS1v15 err",
            4: "infos error",
            5: "密码错误",
            6: "getServer nil",
            7: "",
            8: "",
            9: "",
        }
        // listenerConn and serverAddr describe the current dispatcher
        // connection; they are reset to nil between reconnect attempts.
        listenerConn *net.UDPConn
        serverAddr   *net.UDPAddr
        // mapMutex guards forwardMap.
        mapMutex     sync.Mutex
        // forwardMap holds one connected UDP socket to a local service per
        // client index.
        forwardMap   = map[int]*net.UDPConn{}
        // natSuccess is set once the dispatcher acknowledges a heartbeat.
        natSuccess   = false
        // maxLength records the largest datagram seen, for the stats log.
        // NOTE(review): written from several goroutines without synchronization.
        maxLength    = 0
    )
    
    // main runs the NAT client. A single optional command-line argument
    // replaces the default dispatcher address.
    func main() {
        args := os.Args[1:]
        log.Println("入参: " + strings.Join(args, " "))
        if len(args) == 1 {
            natDispatcherAddr = args[0]
        }
        log.SetFlags(log.LstdFlags | log.Lshortfile)
        startClient()
    }
    
    // Start is the embeddable entry point: it installs the dispatcher address,
    // password and port mappings, then runs the client (never returns).
    func Start(natDispatcherAddrStr string, natPasswordStr string, natMapStr []string) {
        natDispatcherAddr, natPassword, natMapArr = natDispatcherAddrStr, natPasswordStr, natMapStr
        startClient()
    }
    
    // startClient runs the NAT client and never returns. It:
    //  1. parses natMapArr entries of the form "<serverOpenAddr>-<localAddr>";
    //  2. encrypts "<password>-<openAddr>-<openAddr>..." once with the RSA public
    //     key (OAEP, SHA-1) to serve as the heartbeat payload;
    //  3. keeps a dispatcher connection via creatClient, reconnecting one second
    //     after it returns;
    //  4. sends the heartbeat [1, ciphertext...] every rateInterval;
    //  5. logs the configured mappings once a minute.
    // It panics if the key or a natMapArr entry is invalid.
    func startClient() {
        log.Println("NatU分发服务地址: " + natDispatcherAddr)
        publicKey, err := loadPublicKey(publicKeyStr)
        if err != nil {
            log.Panicln(err)
        }
        natParams := []string{natPassword}
        // localServerAddrs[i] is the local target for the i-th open address; the
        // dispatcher refers to it by i (data frames start with byte 10+i).
        localServerAddrs := []string{}
        for _, v := range natMapArr {
            mapArr := strings.Split(v, "-")
            // Name the bad entry instead of a bare index-out-of-range panic.
            if len(mapArr) < 2 {
                log.Panicln("invalid natMap entry (want <serverOpenAddr>-<localAddr>):", v)
            }
            natServerOpenAddr := mapArr[0]
            localServerAddr := mapArr[1]
            log.Println("NatU服务器开放地址: "+natServerOpenAddr, "本地服务地址: "+localServerAddr)
            natParams = append(natParams, natServerOpenAddr)
            localServerAddrs = append(localServerAddrs, localServerAddr)
        }
        // Payload: password followed by the addresses to open.
        natInfo, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, []byte(strings.Join(natParams, "-")), nil)
        if err != nil {
            log.Panicln(err)
        }
        // The heartbeat never changes, so build it once.
        heartbeat := append([]byte{1}, natInfo...)

        // Connection loop: creatClient returns when the dispatcher socket fails
        // or stays silent for timeOut; reset the shared state and reconnect.
        go func() {
            for {
                log.Println("startNat")
                creatClient(localServerAddrs)
                listenerConn = nil
                serverAddr = nil
                natSuccess = false
                time.Sleep(time.Second)
            }
        }()

        // Heartbeat loop.
        go func() {
            for {
                // Snapshot the globals once: the connection loop may reset them
                // to nil between a nil check and a later re-read, which would
                // dereference a nil *net.UDPConn.
                conn, addr := listenerConn, serverAddr
                if conn == nil || addr == nil {
                    time.Sleep(time.Second)
                    continue
                }
                conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
                conn.WriteToUDP(heartbeat, addr)
                time.Sleep(rateInterval)
            }
        }()
        for {
            time.Sleep(time.Second * 60)
            println()
            log.Println("natU转发Size:", len(natMapArr), ",maxLength:", maxLength, ",natSuccess:", natSuccess)
            for _, value := range natMapArr {
                log.Println("natU转发中:", strings.ReplaceAll(value, "-", "=>"))
            }
        }
    }
    
    // creatClient opens a fresh UDP socket and publishes it, with the resolved
    // dispatcher address, through the listenerConn/serverAddr globals so the
    // heartbeat loop can use it. It then handles frames from the dispatcher
    // until a read fails (e.g. nothing arrives for timeOut). Frames:
    //   1                      heartbeat ack; marks the tunnel as established.
    //   2..9                   error status from the server; the process panics.
    //   10+i, hi, lo, payload  data for local service localServerAddrs[i] from
    //                          the public-side peer with index hi<<8|lo.
    // Data frames that are too short, or whose i is out of range, are logged and
    // dropped instead of panicking on a bad index.
    func creatClient(localServerAddrs []string) {
        var err error
        serverAddr, err = net.ResolveUDPAddr("udp", natDispatcherAddr)
        if err != nil {
            log.Println(err)
            return
        }
        listenerAddr, err := net.ResolveUDPAddr("udp", ":0")
        if err != nil {
            log.Println(err)
            return
        }
        listenerConn, err = net.ListenUDP("udp", listenerAddr)
        if err != nil {
            log.Println(err)
            return
        }
        defer listenerConn.Close()
        buffer := make([]byte, 1024*64)
        for {
            listenerConn.SetReadDeadline(time.Now().Add(timeOut))
            n, clientAddr, err := listenerConn.ReadFromUDP(buffer)
            if err != nil {
                log.Println(err)
                break
            }
            if n > maxLength {
                maxLength = n
            }
            if n < 1 {
                log.Println("异常消息:", clientAddr)
                continue
            }
            data := make([]byte, n)
            copy(data, buffer)

            // Only the source port is compared. Presumably the IP comparison was
            // disabled because a multi-homed server may answer from a different
            // address of the same host (see the article's note on fd00::5/fd00::6).
            // if clientAddr.Port != serverAddr.Port || clientAddr.IP.String() != serverAddr.IP.String() {
            if clientAddr.Port != serverAddr.Port {
                log.Println("异常消息:", serverAddr, clientAddr)
                continue
            }
            cmd := data[0]
            switch {
            case cmd == 1:
                if !natSuccess {
                    natSuccess = true
                    log.Println("natU建立成功")
                }
            case cmd > 1 && cmd < 10:
                errMsg := errMap[int(cmd)]
                if errMsg == "" {
                    errMsg = "config error"
                }
                log.Panicln(cmd, errMsg, natMapArr)
            case cmd >= 10:
                target := int(cmd) - 10
                // Guard both the index bytes and the mapping slot; either being
                // missing used to crash the client with an index panic.
                if n < 3 || target >= len(localServerAddrs) {
                    log.Println("异常消息:", serverAddr, clientAddr)
                    continue
                }
                clinetIndex := int(data[1])<<8 | int(data[2])
                handleClientRequest(clientAddr, data[3:], localServerAddrs[target], clinetIndex)
            }
        }
    }
    
    // handleClientRequest forwards clientData, received from the dispatcher for
    // client index clinetIndex, to the local service at forwardAddress.
    // clientAddr is the address the frame came from (the dispatcher); replies
    // are sent back there.
    // One connected UDP socket per client index is kept in forwardMap. Its
    // goroutine relays every reply from the local service back as
    // [0, index>>8, index&0xFF, payload...]. It closes and removes the socket
    // after timeOut without replies, or when the socket is closed.
    func handleClientRequest(clientAddr *net.UDPAddr, clientData []byte, forwardAddress string, clinetIndex int) {
        if clientAddr == nil {
            return
        }
        clientAddrString := clientAddr.String()
        mapMutex.Lock()
        defer mapMutex.Unlock()
        forwardConn := forwardMap[clinetIndex]
        if forwardConn == nil {
            forwardAddr, err := net.ResolveUDPAddr("udp", forwardAddress)
            if err != nil {
                log.Println(err)
                return
            }
            forwardConn, err = net.DialUDP("udp", nil, forwardAddr)
            if err != nil {
                log.Println(err)
                return
            }
            infoStr := clientAddrString + "=>" + forwardAddress + "=>" + forwardConn.LocalAddr().String() + "=>" + forwardConn.RemoteAddr().String()
            log.Println("添加udp转发:" + infoStr)
            forwardMap[clinetIndex] = forwardConn
            conn := forwardConn
            go func() {
                defer conn.Close()
                buffer := make([]byte, 1024*64)
                forwardSuccess := false
                for {
                    conn.SetReadDeadline(time.Now().Add(timeOut))
                    // replyAddr is deliberately not named serverAddr, to avoid
                    // shadowing the package-level dispatcher address.
                    n, replyAddr, err := conn.ReadFromUDP(buffer)
                    if err != nil {
                        log.Println(err)
                        if errors.Is(err, net.ErrClosed) {
                            break
                        }
                        if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
                            break
                        }
                        time.Sleep(time.Second)
                        continue
                    }
                    // if serverAddr.Port != forwardAddr.Port || serverAddr.IP.String() != forwardAddr.IP.String() {
                    if replyAddr.Port != forwardAddr.Port {
                        log.Println("异常消息:", replyAddr.String(), forwardAddr.String())
                        continue
                    }
                    if n > maxLength {
                        maxLength = n
                    }
                    if !forwardSuccess {
                        forwardSuccess = true
                        log.Println("udp转发成功:", replyAddr.String(), n, clientAddrString)
                    }
                    // listenerConn is replaced, and nil in between, whenever the
                    // dispatcher connection is re-established. Snapshot it and
                    // drop this reply rather than dereference nil.
                    dispatcher := listenerConn
                    if dispatcher == nil {
                        continue
                    }
                    frame := append([]byte{0, byte(clinetIndex >> 8), byte(clinetIndex)}, buffer[:n]...)
                    dispatcher.SetWriteDeadline(time.Now().Add(time.Second * 5))
                    dispatcher.WriteToUDP(frame, clientAddr)
                }
                log.Println("移除udp:" + infoStr)
                mapMutex.Lock()
                // Only remove the entry if it still refers to this socket.
                if forwardMap[clinetIndex] == conn {
                    delete(forwardMap, clinetIndex)
                }
                mapMutex.Unlock()
            }()
        }
        forwardConn.SetWriteDeadline(time.Now().Add(time.Second * 5))
        forwardConn.Write(clientData)
    }
    
    // loadPublicKey decodes the first PEM block of publicKeyStr and parses it as
    // a PKIX (SubjectPublicKeyInfo) RSA public key. It fails if no PEM block is
    // found, the DER is not valid PKIX, or the key is not RSA.
    func loadPublicKey(publicKeyStr string) (publicKey *rsa.PublicKey, err error) {
        block, _ := pem.Decode([]byte(publicKeyStr))
        if block == nil {
            return nil, fmt.Errorf("解码公钥失败")
        }
        parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
        if err != nil {
            return nil, err
        }
        if rsaKey, isRSA := parsed.(*rsa.PublicKey); isRSA {
            return rsaKey, nil
        }
        return nil, fmt.Errorf("非法公钥文件")
    }
    
    
    

    存在的问题:

    客户端 fd00::2 向服务端 fd00::5 发送消息时, 服务端知道消息来自 fd00::2, 但不知道是由本机哪个 IP 接收的, 也无法控制用哪个 IP 回复. 测试中发现服务端可能会用 fd00::6 回复 fd00::2, 而在部分网络下这样的消息发送不过去 (这也是我没用打洞法的原因), 问题就出在这里. 解决方案也很简单: 分别监听每个 IP. 但这样还需要监听设备 IP 的变化, 我不想这样做. 不知道有没有大佬有更好的解决方案.

    画个草图好理解些:

    (原文此处为示意图 image.png, 图片未能保留)

    相关文章

      网友评论

          本文标题:go语言实现内网穿透(udp版)

          本文链接:https://www.haomeiwen.com/subject/nlxjdjtx.html