美文网首页
go语言实现内网穿透

go语言实现内网穿透

作者: 今天i你好吗 | 来源:发表于2024-02-13 16:40 被阅读0次

    相关内容

    node.js实现内网穿透: https://www.jianshu.com/p/d2d4f8bff599
    kotlin实现内网穿透: https://www.jianshu.com/p/c8dc095c758e
    可以和 node.js、kotlin 版混用
    使用方式见 node.js 版,大同小异,部分逻辑稍有调整

    实现代码

    服务端:

    package main
    
    import (
        "crypto/rsa"
        "crypto/sha1"
        "crypto/x509"
        "encoding/pem"
        "fmt"
        "io"
        "log"
        "net"
        "os"
        "strings"
        "sync"
        "time"
    )
    
    // Package-level configuration and shared state of the dispatcher server.
    var (
        // privateKeyStr is the PEM text that main passes to loadPrivateKey; the
        // resulting RSA key decrypts the registration header in handleClientRequest.
        // NOTE(review): the key is embedded in source — presumably for demo use
        // only; confirm before deploying.
        privateKeyStr = "-----BEGIN PRIVATE KEY-----\n" +
            "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAJYBf37uy0sVXyxb\n" +
            "bMDjvQzxv/ke3UWCSkYhUd8e+MjGHeT8A9V9aVemg3qUogND/Pgtlz6bTd9p5H+q\n" +
            "OCXnrZbSwdAN7O/9x3zhRzaEOH+9OJ8vpF80DuatbdqqJXFPiDO+nfbufNyT8n+b\n" +
            "N9UXISAVv7Nay+vTEySD401czDDzAgMBAAECgYAE9u26TLrrtDxfInN5+s+R8xpQ\n" +
            "a2YVW9eLdKTaBpNjSbNJldGmqiznWrp1PyARjZl8uT2NM+Si5UVLuF19W6qSC1Tw\n" +
            "bn2brBSsFsufKQff0XXRsD8OqKT5h5/PlRATYgisZonD/v0SPfDHROcNCFiOelEv\n" +
            "kKZAJgJS3vghZx5jCQJBAMTb4esPOApK+6wXlfVShvRiac9UWW9KBYLwpqya5Cjd\n" +
            "X/QUGVwywMFkNgRnBz7yWdJd1zuXfuI77N87BXYOE68CQQDDEjfaKCyZJ73RfNJo\n" +
            "ED5DRfZXm86RPBDlZznJ4shjNVbk1sGClYAk0WuAHeIZJVpm1HpME6NSCfumqeOq\n" +
            "z1P9AkAQu+xJcgK+hT89ksexkfFc5ty9vhrYJf+v8MsKUyRgAOl+MxMwzjOqfN1G\n" +
            "pIduJ2XRRx7btvYXPybUlwzQy0OLAkBIexlznt/LXH/kOcv4TKjF2FYLAWKEhlwE\n" +
            "0REg2Xn5mtUZnE40lhYSGBoodXIQQ9fOQ37Zi6ZwkjMGHzPvwK+FAkAe9gHRMI2u\n" +
            "lzwVp/AQntBMXmw92IULIlfRmfV1jDBYuT0JHUUGZqfCrz+iDW8Ot24QBzLbwKxJ\n" +
            "JRrmXbUxObmC\n" +
            "-----END PRIVATE KEY-----"
        // natPasswords lists the passwords handleClientRequest accepts in a
        // registration header.
        natPasswords = []string{"yzh"}
        // headSize is the number of bytes handleClientRequest reads as the
        // encrypted registration header.
        headSize     = 128
        // serverMap holds the Server for each exposed port, keyed by the port
        // part of the requested address (see getServer).
        serverMap    = map[string]*Server{}
        // mutex is held by getServer while it reads or writes serverMap.
        mutex        sync.Mutex
        // timeOut is the duration added to time.Now() for every read and write
        // deadline set in this program.
        timeOut      = time.Second * 15
    )
    
    // ConnPair links one registered NAT-client connection with the public
    // connection it has been matched to.
    type ConnPair struct {
        // serverClient is the connection from the NAT client; it is set when
        // handleClientRequest creates the pair.
        serverClient *net.Conn
        // client is the accepted public connection; it stays nil until startNat
        // assigns it.
        client       *net.Conn
    }
    
    // Server is the state kept for one exposed public port.
    type Server struct {
        // emptyCount counts consecutive checks by accept's watchdog that found
        // natConnPairs empty; addNatServerClient resets it to 0.
        emptyCount   int
        // address is the listen address this Server was created for in getServer.
        address      string
        // mutex is held by dispatcherNat, addNatServerClient and
        // removeNatServerClient while they touch the two queues below.
        mutex        sync.Mutex
        // wMutex is passed to startNat and is also taken by the heartbeat echo in
        // handleClientRequest; both use it around ConnPair.client.
        wMutex       sync.Mutex
        // natConnPairs queues registered NAT connections that have no public
        // connection yet.
        natConnPairs []*ConnPair
        // paddingConns queues accepted public connections that arrived while
        // natConnPairs was empty ("padding" presumably means "pending").
        paddingConns []*net.Conn
    }
    
    // accept runs the public listener for one exposed port and hands every
    // accepted connection to dispatcherNat.
    //
    // A watchdog goroutine checks once a minute whether any NAT connection is
    // queued; after more than five consecutive empty checks it closes the
    // listener, which ends the accept loop. The loop also ends after more than
    // five consecutive Accept errors.
    //
    // On exit the listener is closed, the Server is removed from serverMap (under
    // the same global mutex getServer uses) and any public connections still
    // waiting in paddingConns are closed so they do not leak.
    //
    // listener is the listener created by getServer; port is the serverMap key
    // under which this Server was registered.
    func (server *Server) accept(listener *net.Listener, port string) {
        ln := *listener
        done := make(chan struct{}) // closed when the accept loop has exited
        idle := make(chan struct{}) // closed by the watchdog right before it closes ln

        go func() {
            ticker := time.NewTicker(time.Second * 60)
            defer ticker.Stop()
            for {
                select {
                case <-done:
                    // The accept loop is gone; nothing left to watch.
                    return
                case <-ticker.C:
                }
                // emptyCount and natConnPairs are shared with the add/remove
                // methods, so they are only touched under server.mutex.
                server.mutex.Lock()
                if len(server.natConnPairs) == 0 {
                    server.emptyCount++
                } else {
                    server.emptyCount = 0
                }
                expired := server.emptyCount > 5
                server.mutex.Unlock()
                if expired {
                    close(idle)
                    ln.Close()
                    return
                }
            }
        }()

        failures := 0
    acceptLoop:
        for {
            client, err := ln.Accept()
            if err != nil {
                log.Println(err)
                select {
                case <-idle:
                    // The watchdog closed the listener on purpose: stop at once.
                    break acceptLoop
                default:
                }
                failures++
                if failures > 5 {
                    break
                }
                time.Sleep(time.Second)
                continue
            }
            log.Println(client.LocalAddr(), client.RemoteAddr())
            server.dispatcherNat(&client)
            failures = 0
        }
        close(done)
        // Release the port before unregistering so a later getServer call for
        // the same address can listen again. A second Close is harmless.
        ln.Close()

        mutex.Lock()
        if serverMap[port] == server {
            delete(serverMap, port)
        }
        mutex.Unlock()

        // Nobody will ever dispatch the queued public connections now.
        server.mutex.Lock()
        for _, pending := range server.paddingConns {
            (*pending).Close()
        }
        server.paddingConns = nil
        server.mutex.Unlock()
    }
    // dispatcherNat matches a freshly accepted public connection with the oldest
    // queued NAT connection and starts forwarding in a new goroutine. When no NAT
    // connection is queued, the public connection is appended to paddingConns.
    func (server *Server) dispatcherNat(client *net.Conn) {
        server.mutex.Lock()
        defer server.mutex.Unlock()

        if len(server.natConnPairs) > 0 {
            idlePair := server.natConnPairs[0]
            server.natConnPairs = server.natConnPairs[1:]
            go startNat(idlePair, client, &server.wMutex)
            return
        }
        server.paddingConns = append(server.paddingConns, client)
    }
    // addNatServerClient registers a NAT connection with this Server and resets
    // emptyCount. If a public connection is already waiting in paddingConns, the
    // oldest one is paired with connPair immediately; otherwise connPair is
    // queued in natConnPairs.
    func (server *Server) addNatServerClient(connPair *ConnPair) {
        server.mutex.Lock()
        defer server.mutex.Unlock()

        server.emptyCount = 0
        if len(server.paddingConns) > 0 {
            waiting := server.paddingConns[0]
            server.paddingConns = server.paddingConns[1:]
            go startNat(connPair, waiting, &server.wMutex)
            return
        }
        server.natConnPairs = append(server.natConnPairs, connPair)
    }
    // removeNatServerClient drops the first occurrence of connPair from
    // natConnPairs; it does nothing when connPair is not queued.
    func (server *Server) removeNatServerClient(connPair *ConnPair) {
        server.mutex.Lock()
        defer server.mutex.Unlock()

        for idx := 0; idx < len(server.natConnPairs); idx++ {
            if server.natConnPairs[idx] != connPair {
                continue
            }
            head, tail := server.natConnPairs[:idx], server.natConnPairs[idx+1:]
            server.natConnPairs = append(head, tail...)
            return
        }
    }
    // startNat binds a public connection to a NAT connection and forwards data
    // from the public side to the NAT side.
    //
    // It records client in connPair under wMutex, writes the single byte 1 to the
    // NAT connection, then copies public -> NAT with switchData. Both connections
    // are closed when the function returns.
    func startNat(connPair *ConnPair, client *net.Conn, wMutex *sync.Mutex) {
        wMutex.Lock()
        connPair.client = client
        wMutex.Unlock()

        natConn := *connPair.serverClient
        publicConn := *client
        defer natConn.Close()
        defer publicConn.Close()

        natConn.SetWriteDeadline(time.Now().Add(timeOut))
        if _, err := natConn.Write([]byte{1}); err != nil {
            log.Println("Write err", err)
            return
        }
        switchData(natConn, publicConn)
    }
    
    // main starts the dispatcher: it loads the embedded private key, listens on
    // the dispatcher address (":8989" unless exactly one command-line argument
    // overrides it) and serves every accepted connection in its own goroutine.
    func main() {
        args := os.Args[1:]
        log.Println("入参: " + strings.Join(args, " "))

        dispatcherAddress := ":8989"
        if len(args) == 1 {
            dispatcherAddress = args[0]
        }
        log.SetFlags(log.LstdFlags | log.Lshortfile)
        log.Println("Nat分发服务地址: " + dispatcherAddress)

        privateKey, err := loadPrivateKey(privateKeyStr)
        if err != nil {
            log.Panicln(err)
        }
        listener, err := net.Listen("tcp", dispatcherAddress)
        if err != nil {
            log.Panicln(err)
        }

        for {
            client, acceptErr := listener.Accept()
            if acceptErr == nil {
                log.Println(client.LocalAddr(), client.RemoteAddr())
                go handleClientRequest(&client, privateKey)
                continue
            }
            // Back off briefly before retrying a failed Accept.
            log.Println(acceptErr)
            time.Sleep(time.Second)
        }
    }
    
    // handleClientRequest serves one connection made to the dispatcher port.
    //
    // Protocol, as implemented here:
    //  1. Read headSize bytes and RSA-OAEP(SHA-1) decrypt them with privateKey.
    //     The plaintext must be "<listen address>-<password>".
    //  2. On any failure write the single byte 2 and close the connection.
    //  3. Otherwise register the connection with the Server for that address
    //     and enter the heartbeat loop: every byte other than 0 is answered
    //     with a 0 byte until a public connection has been paired (startNat
    //     sets connPair.client); a 0 byte means the client is ready to forward.
    //  4. After the 0 byte, copy NAT connection -> public connection with
    //     switchData.
    //
    // client is the accepted dispatcher connection; it is always closed on return.
    func handleClientRequest(client *net.Conn, privateKey *rsa.PrivateKey) {
        serverClient := *client
        defer serverClient.Close()

        // reject tells the peer that registration failed (status byte 2).
        reject := func() {
            serverClient.SetWriteDeadline(time.Now().Add(timeOut))
            // Best effort: the connection is closed right after, so a failed
            // write changes nothing.
            serverClient.Write([]byte{2})
        }

        serverClient.SetReadDeadline(time.Now().Add(timeOut))
        cmdData := make([]byte, headSize)
        // io.ReadFull reports an error whenever fewer than len(cmdData) bytes
        // were read, so no separate length check is needed.
        if _, err := io.ReadFull(serverClient, cmdData); err != nil {
            log.Println("ReadFull err", err)
            reject()
            return
        }

        decryptedText, err := rsa.DecryptOAEP(sha1.New(), nil, privateKey, cmdData, nil)
        if err != nil {
            log.Println("DecryptOAEP err", err)
            reject()
            return
        }
        infos := strings.Split(string(decryptedText), "-")
        if len(infos) != 2 {
            // Only the field count is logged: the fields may contain the password.
            log.Println("infos error", len(infos))
            reject()
            return
        }
        pwdOk := false
        for _, v := range natPasswords {
            if v == infos[1] {
                pwdOk = true
                break
            }
        }
        if !pwdOk {
            // Log the requested address only, never the submitted password.
            log.Println("密码错误", infos[0])
            reject()
            return
        }
        server := getServer(infos[0])
        if server == nil {
            log.Println("getServer nil")
            reject()
            return
        }
        connPair := &ConnPair{serverClient: &serverClient}
        server.addNatServerClient(connPair)

        buf := make([]byte, 1)
        n := 1
        for {
            if n == 1 {
                // Extend the read deadline only after a byte actually arrived.
                serverClient.SetReadDeadline(time.Now().Add(timeOut))
            }
            n, err = serverClient.Read(buf)
            if err != nil {
                server.removeNatServerClient(connPair)
                return
            }
            if n != 1 {
                continue
            }
            if buf[0] == 0 {
                break
            }
            // Echo the heartbeat unless pairing has already started. wMutex
            // orders this echo against startNat's assignment of connPair.client.
            echoed := func() bool {
                server.wMutex.Lock()
                defer server.wMutex.Unlock()
                if connPair.client != nil {
                    log.Println("nating")
                    return true
                }
                serverClient.SetWriteDeadline(time.Now().Add(timeOut))
                buf[0] = 0
                _, werr := serverClient.Write(buf)
                return werr == nil
            }()
            if !echoed {
                // The connection is unusable: unregister it and stop.
                server.removeNatServerClient(connPair)
                return
            }
        }

        server.wMutex.Lock()
        natClient := connPair.client
        server.wMutex.Unlock()
        if natClient == nil {
            // The peer announced "ready" before any public connection was
            // paired. Unregister the pair so the closed connection does not
            // stay in the idle queue.
            log.Println("nat client nil")
            server.removeNatServerClient(connPair)
            return
        }
        defer (*natClient).Close()
        switchData(*natClient, serverClient)
    }
    
    // getServer returns the Server listening on address, creating it and starting
    // its accept loop when the port is not in use yet.
    //
    // address must split into exactly two parts at ":"; the second part (the
    // port) is the serverMap key. nil is returned when the address is malformed,
    // when the port is already registered under a different address, or when
    // listening fails.
    func getServer(address string) *Server {
        parts := strings.Split(address, ":")
        if len(parts) != 2 {
            log.Println("address error", address)
            return nil
        }
        port := parts[1]

        mutex.Lock()
        defer mutex.Unlock()

        if existing := serverMap[port]; existing != nil {
            if existing.address == address {
                return existing
            }
            log.Println("端口重复", existing.address, address)
            return nil
        }

        listener, err := net.Listen("tcp", address)
        if err != nil {
            log.Println("listen error", err)
            return nil
        }
        created := &Server{address: address}
        serverMap[port] = created
        go created.accept(&listener, port)
        return created
    }
    
    // loadPrivateKey decodes the first PEM block of privateKeyStr and parses it
    // as a PKCS#8 private key.
    //
    // It returns an error when no PEM block is found, when PKCS#8 parsing fails,
    // or when the parsed key is not an RSA key.
    func loadPrivateKey(privateKeyStr string) (privateKey *rsa.PrivateKey, err error) {
        pemBlock, _ := pem.Decode([]byte(privateKeyStr))
        if pemBlock == nil {
            return nil, fmt.Errorf("解码私钥失败")
        }
        parsed, parseErr := x509.ParsePKCS8PrivateKey(pemBlock.Bytes)
        if parseErr != nil {
            return nil, parseErr
        }
        if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
            return rsaKey, nil
        }
        return nil, fmt.Errorf("非法私钥文件")
    }
    
    // switchData copies bytes from src to dst until src reports an error or EOF,
    // or a write to dst fails.
    //
    // Before the first read, and again after every read that delivered data, the
    // read deadlines of both connections are pushed timeOut into the future; the
    // write deadline of dst is refreshed before each write. It returns the number
    // of bytes written and the first error other than io.EOF, which is also
    // logged.
    func switchData(dst net.Conn, src net.Conn) (written int64, err error) {
        chunk := make([]byte, 10240)
        refresh := true
    copyLoop:
        for {
            if refresh {
                src.SetReadDeadline(time.Now().Add(timeOut))
                dst.SetReadDeadline(time.Now().Add(timeOut))
                refresh = false
            }
            got, readErr := src.Read(chunk)
            if got > 0 {
                refresh = true
                dst.SetWriteDeadline(time.Now().Add(timeOut))
                sent, writeErr := dst.Write(chunk[:got])
                if sent < 0 || sent > got {
                    // A writer reporting an impossible count is treated as failed.
                    sent = 0
                    if writeErr == nil {
                        writeErr = fmt.Errorf("invalid write result")
                    }
                }
                written += int64(sent)
                switch {
                case writeErr != nil:
                    err = writeErr
                    break copyLoop
                case sent != got:
                    err = io.ErrShortWrite
                    break copyLoop
                }
            }
            if readErr != nil {
                if readErr != io.EOF {
                    err = readErr
                }
                break
            }
        }
        if err != nil {
            log.Println(err)
        }
        return written, err
    }
    
    

    客户端:

    package main
    
    import (
        "crypto/rand"
        "crypto/rsa"
        "crypto/sha1"
        "crypto/x509"
        "encoding/pem"
        "fmt"
        "io"
        "log"
        "net"
        "os"
        "strings"
        "sync"
        "time"
    )
    
    // Package-level configuration of the NAT client.
    var (
        // publicKeyStr is the PEM text that main passes to loadPublicKey; the
        // resulting RSA key encrypts the registration header.
        publicKeyStr = "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCWAX9+7stLFV8sW2zA470M8b/5\nHt1FgkpGIVHfHvjIxh3k/APVfWlXpoN6lKIDQ/z4LZc+m03faeR/qjgl562W0sHQ\nDezv/cd84Uc2hDh/vTifL6RfNA7mrW3aqiVxT4gzvp327nzck/J/mzfVFyEgFb+z\nWsvr0xMkg+NNXMww8wIDAQAB\n-----END PUBLIC KEY-----\n"
    
        // natDispatcherAddr is the dispatcher address startNat dials.
        natDispatcherAddr = "192.168.10.18:8989"
        // natServerOpenAddr is the address placed before "-" in the registration
        // header; presumably the address the dispatcher should open — confirm
        // against the server program.
        natServerOpenAddr = ":9001"
        // localServerAddr is the local service startServer dials once a tunnel
        // has been started.
        localServerAddr   = "127.0.0.1:8080"
        // natPassword is placed after "-" in the registration header.
        natPassword       = "yzh"
        // maxFreeNat is the number of start loops main runs (maxFreeNat-1 in
        // goroutines plus one in the foreground).
        maxFreeNat        = 3
        // rateInterval is the pause between two heartbeat writes in startNat.
        rateInterval      = time.Second * 5
        // timeOut is used for dial timeouts and for every read/write deadline.
        timeOut           = time.Second * 15
    )
    
    // main logs the configuration, builds the RSA-OAEP encrypted registration
    // header ("<natServerOpenAddr>-<natPassword>") once, and then runs maxFreeNat
    // registration loops: all but one in goroutines, the last in the foreground.
    func main() {
        log.Println("入参: " + strings.Join(os.Args[1:], " "))
        log.SetFlags(log.LstdFlags | log.Lshortfile)
        log.Println("Nat分发服务地址: " + natDispatcherAddr)
        log.Println("Nat服务器开放地址: " + natServerOpenAddr)
        log.Println("本地服务地址: " + localServerAddr)

        publicKey, err := loadPublicKey(publicKeyStr)
        if err != nil {
            log.Panicln(err)
        }

        plain := []byte(natServerOpenAddr + "-" + natPassword)
        natInfo, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, plain, nil)
        if err != nil {
            log.Panicln(err)
        }

        // Workers 1..maxFreeNat-1 run in the background; worker 0 keeps the
        // process alive.
        for worker := 1; worker < maxFreeNat; worker++ {
            go start(natInfo, worker)
        }
        start(natInfo, 0)
    }
    // start runs startNat in an endless loop, logging the worker index before
    // every attempt. It never returns.
    func start(natInfo []byte, index int) {
        attempt := func() {
            log.Println("startNatId", index)
            startNat(natInfo)
        }
        for {
            attempt()
        }
    }
    // startNat registers one idle tunnel with the dispatcher and waits until the
    // dispatcher assigns it a public connection.
    //
    // It dials natDispatcherAddr, sends natInfo, and then:
    //   - a background goroutine writes the heartbeat byte 1 every rateInterval
    //     until pairing starts or a write fails;
    //   - the calling goroutine reads status bytes: 1 means "paired" (forwarding
    //     is handed to startServer in a new goroutine), 2 means the dispatcher
    //     rejected the registration (the process panics); other bytes only
    //     refresh the read deadline.
    //
    // On dial failure it sleeps one second before returning; on any other I/O
    // error it closes the connection and returns.
    func startNat(natInfo []byte) {
        server, err := net.DialTimeout("tcp", natDispatcherAddr, timeOut)
        if err != nil {
            log.Println(err)
            time.Sleep(time.Second)
            return
        }
        server.SetWriteDeadline(time.Now().Add(timeOut))
        n, err := server.Write(natInfo)
        if err != nil {
            log.Println(err)
            server.Close()
            return
        }
        if n != len(natInfo) {
            log.Println("写入长度错误", n, len(natInfo))
            server.Close()
            return
        }

        // mutex guards nating and serialises heartbeat writes against the
        // hand-over to startServer.
        mutex := &sync.Mutex{}
        nating := false
        go func() {
            beat := []byte{1}
            for {
                time.Sleep(rateInterval)
                // The goroutine uses its own error variable: sharing the outer
                // err/n with the reader loop below would be a data race.
                stop := func() bool {
                    mutex.Lock()
                    defer mutex.Unlock()
                    if nating {
                        return true
                    }
                    server.SetWriteDeadline(time.Now().Add(timeOut))
                    if _, werr := server.Write(beat); werr != nil {
                        log.Println(werr)
                        return true
                    }
                    return false
                }()
                if stop {
                    return
                }
            }
        }()

        buf := make([]byte, 1)
        r := 1
        for {
            if r == 1 {
                if buf[0] == 1 {
                    break
                }
                if buf[0] == 2 {
                    log.Panicln("config error or port used", natServerOpenAddr)
                }
                server.SetReadDeadline(time.Now().Add(timeOut))
            }
            r, err = server.Read(buf)
            if err != nil {
                log.Println(err)
                server.Close()
                return
            }
        }
        // Stop the heartbeat before the start byte is sent by startServer.
        mutex.Lock()
        nating = true
        mutex.Unlock()
        go startServer(&server)
    }
    
    // startServer turns a paired dispatcher connection into a live tunnel.
    //
    // It writes the start byte 0 to the dispatcher connection, dials
    // localServerAddr and copies data in both directions with switchData. As soon
    // as either direction ends, the function returns and both connections are
    // closed, which also unblocks the other direction. Previously only the
    // tunnel -> local direction triggered the teardown, so a local service that
    // closed first left the tunnel open until the idle timeout expired.
    //
    // client is the dispatcher connection; it is always closed on return.
    func startServer(client *net.Conn) {
        tunnel := *client
        defer tunnel.Close()

        tunnel.SetWriteDeadline(time.Now().Add(timeOut))
        n, err := tunnel.Write([]byte{0})
        if err != nil {
            log.Println(err)
            return
        }
        if n != 1 {
            log.Println("写入开始命令错误")
            return
        }
        server, err := net.DialTimeout("tcp", localServerAddr, timeOut)
        if err != nil {
            log.Println(err)
            return
        }
        defer server.Close()

        // Buffered for both senders so the slower goroutine never blocks on the
        // channel after this function has returned.
        finished := make(chan struct{}, 2)
        go func() {
            switchData(tunnel, server)
            finished <- struct{}{}
        }()
        go func() {
            switchData(server, tunnel)
            finished <- struct{}{}
        }()
        <-finished
    }
    
    // loadPublicKey decodes the first PEM block of publicKeyStr and parses it as
    // a PKIX public key.
    //
    // It returns an error when no PEM block is found, when PKIX parsing fails, or
    // when the parsed key is not an RSA key.
    func loadPublicKey(publicKeyStr string) (publicKey *rsa.PublicKey, err error) {
        pemBlock, _ := pem.Decode([]byte(publicKeyStr))
        if pemBlock == nil {
            return nil, fmt.Errorf("解码公钥失败")
        }
        parsed, parseErr := x509.ParsePKIXPublicKey(pemBlock.Bytes)
        if parseErr != nil {
            return nil, parseErr
        }
        if rsaKey, isRSA := parsed.(*rsa.PublicKey); isRSA {
            return rsaKey, nil
        }
        return nil, fmt.Errorf("非法公钥文件")
    }
    
    // switchData copies bytes from src to dst until src reports an error or EOF,
    // or a write to dst fails.
    //
    // Before the first read, and again after every read that delivered data, the
    // read deadlines of both connections are pushed timeOut into the future; the
    // write deadline of dst is refreshed before each write. It returns the number
    // of bytes written and the first error other than io.EOF, which is also
    // logged.
    func switchData(dst net.Conn, src net.Conn) (written int64, err error) {
        chunk := make([]byte, 10240)
        refresh := true
    copyLoop:
        for {
            if refresh {
                src.SetReadDeadline(time.Now().Add(timeOut))
                dst.SetReadDeadline(time.Now().Add(timeOut))
                refresh = false
            }
            got, readErr := src.Read(chunk)
            if got > 0 {
                refresh = true
                dst.SetWriteDeadline(time.Now().Add(timeOut))
                sent, writeErr := dst.Write(chunk[:got])
                if sent < 0 || sent > got {
                    // A writer reporting an impossible count is treated as failed.
                    sent = 0
                    if writeErr == nil {
                        writeErr = fmt.Errorf("invalid write result")
                    }
                }
                written += int64(sent)
                switch {
                case writeErr != nil:
                    err = writeErr
                    break copyLoop
                case sent != got:
                    err = io.ErrShortWrite
                    break copyLoop
                }
            }
            if readErr != nil {
                if readErr != io.EOF {
                    err = readErr
                }
                break
            }
        }
        if err != nil {
            log.Println(err)
        }
        return written, err
    }
    
    

    由于未找到较好的读取加密私钥的方案,服务端改为使用未加密的私钥

    相关文章

      网友评论

          本文标题:go语言实现内网穿透

          本文链接:https://www.haomeiwen.com/subject/fytsadtx.html