相关内容
node.js实现内网穿透: https://www.jianshu.com/p/d2d4f8bff599
kotlin实现内网穿透: https://www.jianshu.com/p/c8dc095c758e
可以与 node.js 版、kotlin 版混用。
使用方式见 node.js 版，大同小异，部分逻辑稍有调整。
实现代码
服务端:
package main
import (
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"time"
)
// Dispatcher configuration and process-wide shared state.
var (
// privateKeyStr is the unencrypted PKCS#8 RSA private key (PEM) used to
// decrypt the handshake header sent by NAT clients, which encrypt it with the
// corresponding public key.
// NOTE(review): the key is hard-coded in source, so anyone who has this file
// can decrypt client handshakes; consider loading it from a file or env var.
privateKeyStr = "-----BEGIN PRIVATE KEY-----\n" +
"MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAJYBf37uy0sVXyxb\n" +
"bMDjvQzxv/ke3UWCSkYhUd8e+MjGHeT8A9V9aVemg3qUogND/Pgtlz6bTd9p5H+q\n" +
"OCXnrZbSwdAN7O/9x3zhRzaEOH+9OJ8vpF80DuatbdqqJXFPiDO+nfbufNyT8n+b\n" +
"N9UXISAVv7Nay+vTEySD401czDDzAgMBAAECgYAE9u26TLrrtDxfInN5+s+R8xpQ\n" +
"a2YVW9eLdKTaBpNjSbNJldGmqiznWrp1PyARjZl8uT2NM+Si5UVLuF19W6qSC1Tw\n" +
"bn2brBSsFsufKQff0XXRsD8OqKT5h5/PlRATYgisZonD/v0SPfDHROcNCFiOelEv\n" +
"kKZAJgJS3vghZx5jCQJBAMTb4esPOApK+6wXlfVShvRiac9UWW9KBYLwpqya5Cjd\n" +
"X/QUGVwywMFkNgRnBz7yWdJd1zuXfuI77N87BXYOE68CQQDDEjfaKCyZJ73RfNJo\n" +
"ED5DRfZXm86RPBDlZznJ4shjNVbk1sGClYAk0WuAHeIZJVpm1HpME6NSCfumqeOq\n" +
"z1P9AkAQu+xJcgK+hT89ksexkfFc5ty9vhrYJf+v8MsKUyRgAOl+MxMwzjOqfN1G\n" +
"pIduJ2XRRx7btvYXPybUlwzQy0OLAkBIexlznt/LXH/kOcv4TKjF2FYLAWKEhlwE\n" +
"0REg2Xn5mtUZnE40lhYSGBoodXIQQ9fOQ37Zi6ZwkjMGHzPvwK+FAkAe9gHRMI2u\n" +
"lzwVp/AQntBMXmw92IULIlfRmfV1jDBYuT0JHUUGZqfCrz+iDW8Ot24QBzLbwKxJ\n" +
"JRrmXbUxObmC\n" +
"-----END PRIVATE KEY-----"
// natPasswords lists the passwords a NAT client may present in its handshake.
natPasswords = []string{"yzh"}
// headSize is the length in bytes of the encrypted handshake header. The whole
// header is passed to rsa.DecryptOAEP, so it must equal the RSA key size in bytes.
headSize = 128
// serverMap maps an exposed port (as a string) to the Server listening on it.
// getServer accesses it while holding mutex.
serverMap = map[string]*Server{}
// mutex guards serverMap.
mutex sync.Mutex
// timeOut is the read/write deadline applied to connection operations.
timeOut = time.Second * 15
)
type (
	// ConnPair couples an idle control connection opened by a NAT client
	// (serverClient) with the public connection it is eventually assigned to
	// (client), which stays nil until startNat assigns one.
	ConnPair struct {
		serverClient, client *net.Conn
	}

	// Server is the state behind one exposed listen address.
	Server struct {
		// emptyCount counts consecutive watchdog checks in accept that found
		// no idle NAT connection.
		emptyCount int
		// address is the listen address requested by the NAT client.
		address string
		// mutex is held by dispatcherNat, addNatServerClient and
		// removeNatServerClient while they touch the two queues; wMutex
		// serializes publishing ConnPair.client with heartbeat replies.
		mutex, wMutex sync.Mutex
		// natConnPairs holds idle NAT connections waiting for a public client.
		natConnPairs []*ConnPair
		// paddingConns holds public clients waiting for an idle NAT connection.
		paddingConns []*net.Conn
	}
)
// accept runs the public listener for one exposed port and hands every
// accepted connection to dispatcherNat.
//
// A watchdog goroutine checks once a minute whether any idle NAT connection is
// queued; after more than five consecutive empty checks it closes the
// listener. The accept loop also gives up after more than five consecutive
// Accept errors. On exit the listener is closed (releasing the port), the
// server is removed from serverMap if it is still registered there, and every
// connection it still queues is closed.
//
// NOTE(review): a handleClientRequest that obtained this server from
// getServer just before shutdown can still register a pair on it afterwards.
func (server *Server) accept(listener *net.Listener, port string) {
	idleClosed := make(chan struct{}) // closed by the watchdog right before it closes the listener
	stop := make(chan struct{})       // closed when the accept loop ends so the watchdog exits

	go func() {
		ticker := time.NewTicker(time.Second * 60)
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
			}
			// natConnPairs and emptyCount are shared with dispatcherNat and
			// addNatServerClient, so inspect/update them under server.mutex.
			server.mutex.Lock()
			if len(server.natConnPairs) == 0 {
				server.emptyCount++
			} else {
				server.emptyCount = 0
			}
			idle := server.emptyCount > 5
			server.mutex.Unlock()
			if idle {
				close(idleClosed)
				(*listener).Close()
				return
			}
		}
	}()

	failures := 0
	for {
		client, err := (*listener).Accept()
		if err != nil {
			log.Println(err)
			failures++
			shutDown := false
			select {
			case <-idleClosed:
				shutDown = true
			default:
			}
			if shutDown || failures > 5 {
				break
			}
			time.Sleep(time.Second)
			continue
		}
		log.Println(client.LocalAddr(), client.RemoteAddr())
		server.dispatcherNat(&client)
		failures = 0
	}
	close(stop)

	// Release the port even when we stopped because of repeated Accept errors;
	// otherwise a later getServer for this port could never listen again.
	(*listener).Close()

	// serverMap is guarded by the package-level mutex; only remove our own entry.
	mutex.Lock()
	if serverMap[port] == server {
		delete(serverMap, port)
	}
	mutex.Unlock()

	// Nothing will dispatch these any more: close them so their owners
	// (public clients, handleClientRequest goroutines) are released.
	server.mutex.Lock()
	for _, c := range server.paddingConns {
		(*c).Close()
	}
	for _, p := range server.natConnPairs {
		(*p.serverClient).Close()
	}
	server.paddingConns = nil
	server.natConnPairs = nil
	server.mutex.Unlock()
}
// dispatcherNat pairs a newly accepted public connection with the oldest idle
// NAT connection and starts startNat for them in a new goroutine. If no NAT
// connection is idle, the public connection is queued in paddingConns until
// addNatServerClient supplies one.
func (server *Server) dispatcherNat(client *net.Conn) {
	server.mutex.Lock()
	defer server.mutex.Unlock()
	if len(server.natConnPairs) > 0 {
		oldest := server.natConnPairs[0]
		server.natConnPairs = server.natConnPairs[1:]
		go startNat(oldest, client, &server.wMutex)
		return
	}
	server.paddingConns = append(server.paddingConns, client)
}
// addNatServerClient registers a freshly authenticated NAT connection. The
// idle counter used by accept's watchdog is reset. If a public client is
// already waiting, the oldest one is paired with connPair immediately via
// startNat; otherwise connPair is queued as idle.
func (server *Server) addNatServerClient(connPair *ConnPair) {
	server.mutex.Lock()
	defer server.mutex.Unlock()
	server.emptyCount = 0
	if len(server.paddingConns) > 0 {
		waiting := server.paddingConns[0]
		server.paddingConns = server.paddingConns[1:]
		go startNat(connPair, waiting, &server.wMutex)
		return
	}
	server.natConnPairs = append(server.natConnPairs, connPair)
}
// removeNatServerClient drops connPair from the idle queue. It is a no-op when
// the pair is no longer queued (e.g. already handed to startNat).
func (server *Server) removeNatServerClient(connPair *ConnPair) {
	server.mutex.Lock()
	defer server.mutex.Unlock()
	pairs := server.natConnPairs
	for idx := 0; idx < len(pairs); idx++ {
		if pairs[idx] != connPair {
			continue
		}
		server.natConnPairs = append(pairs[:idx], pairs[idx+1:]...)
		return
	}
}
// startNat assigns the public connection client to connPair and runs the
// public-client -> NAT-client half of the tunnel.
//
// client is published under wMutex so that handleClientRequest stops replying
// to heartbeats on this control connection; then a single byte 1 tells the
// NAT client to start. Both connections are closed when this function returns.
func startNat(connPair *ConnPair, client *net.Conn, wMutex *sync.Mutex) {
	wMutex.Lock()
	connPair.client = client
	wMutex.Unlock()

	tunnel := *connPair.serverClient
	visitor := *connPair.client
	defer tunnel.Close()
	defer visitor.Close()

	tunnel.SetWriteDeadline(time.Now().Add(timeOut))
	if _, err := tunnel.Write([]byte{1}); err != nil {
		log.Println("Write err", err)
		return
	}
	switchData(tunnel, visitor)
}
// main starts the NAT dispatcher. It listens on the address given as the only
// command-line argument (default ":8989") and serves every incoming control
// connection with handleClientRequest in its own goroutine.
func main() {
	args := os.Args[1:]
	log.Println("入参: " + strings.Join(args, " "))
	dispatcherAddress := ":8989"
	if len(args) == 1 {
		dispatcherAddress = args[0]
	}
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	log.Println("Nat分发服务地址: " + dispatcherAddress)

	privateKey, err := loadPrivateKey(privateKeyStr)
	if err != nil {
		log.Panicln(err)
	}
	listener, err := net.Listen("tcp", dispatcherAddress)
	if err != nil {
		log.Panicln(err)
	}
	for {
		conn, acceptErr := listener.Accept()
		if acceptErr == nil {
			log.Println(conn.LocalAddr(), conn.RemoteAddr())
			go handleClientRequest(&conn, privateKey)
			continue
		}
		log.Println(acceptErr)
		time.Sleep(time.Second)
	}
}
// handleClientRequest serves one control connection opened by a NAT client
// (the intranet side of the tunnel).
//
// Handshake: the client sends headSize bytes of RSA-OAEP (SHA-1) ciphertext
// that must decrypt to "<listenAddress>-<password>". On any handshake failure
// a single byte 2 is written back and the connection is closed.
//
// After a successful handshake the connection is registered with the Server
// for listenAddress as an idle NAT connection. While idle, each byte 1 that
// is received is a heartbeat, answered with byte 0 unless startNat has already
// claimed this pair. A received byte 0 confirms the start signal sent by
// startNat. After that, this goroutine copies NAT-client -> public-client
// traffic until either side ends.
func handleClientRequest(client *net.Conn, privateKey *rsa.PrivateKey) {
	serverClient := *client
	defer serverClient.Close()
	serverClient.SetReadDeadline(time.Now().Add(timeOut))
	cmdData := make([]byte, headSize)
	// io.ReadFull returns a nil error only after filling cmdData completely,
	// so a separate length check would be unreachable.
	_, err := io.ReadFull(serverClient, cmdData)
	buf := make([]byte, 1)
	buf[0] = 2 // handshake-rejected reply
	serverClient.SetWriteDeadline(time.Now().Add(timeOut))
	if err != nil {
		log.Println("ReadFull err", err)
		serverClient.Write(buf)
		return
	}
	decryptedText, err := rsa.DecryptOAEP(sha1.New(), nil, privateKey, cmdData, nil)
	if err != nil {
		log.Println("DecryptOAEP err", err)
		serverClient.Write(buf)
		return
	}
	infos := strings.Split(string(decryptedText), "-")
	if len(infos) != 2 {
		log.Println("infos error", infos)
		serverClient.Write(buf)
		return
	}
	pwdOk := false
	for _, v := range natPasswords {
		if v == infos[1] {
			pwdOk = true
			break
		}
	}
	if !pwdOk {
		log.Println("密码错误", infos)
		serverClient.Write(buf)
		return
	}
	server := getServer(infos[0])
	if server == nil {
		log.Println("getServer nil")
		serverClient.Write(buf)
		return
	}
	connPair := &ConnPair{serverClient: &serverClient}
	server.addNatServerClient(connPair)

	n := 1
	for {
		// Refresh the idle deadline only after a complete 1-byte message.
		if n == 1 {
			serverClient.SetReadDeadline(time.Now().Add(timeOut))
		}
		n, err = serverClient.Read(buf)
		if err != nil {
			server.removeNatServerClient(connPair)
			return
		}
		if n != 1 {
			continue
		}
		if buf[0] == 0 {
			break
		}
		// Heartbeat. Reply under wMutex so the reply cannot interleave with
		// startNat publishing connPair.client and sending its start byte.
		writeErr := func() error {
			server.wMutex.Lock()
			defer server.wMutex.Unlock()
			if connPair.client != nil {
				log.Println("nating")
				return nil
			}
			serverClient.SetWriteDeadline(time.Now().Add(timeOut))
			buf[0] = 0
			_, werr := serverClient.Write(buf)
			return werr
		}()
		if writeErr != nil {
			// The control connection is unusable: stop serving it rather than
			// keep looping on a pair that has been removed from the queue.
			server.removeNatServerClient(connPair)
			return
		}
	}

	// connPair.client is written by startNat under wMutex; read it the same way.
	server.wMutex.Lock()
	natClient := connPair.client
	server.wMutex.Unlock()
	if natClient == nil {
		// Start confirmation without a preceding start signal: make sure this
		// pair is not handed to a public client later.
		log.Println("nat client nil")
		server.removeNatServerClient(connPair)
		return
	}
	defer (*natClient).Close()
	switchData(*natClient, serverClient)
}
// getServer returns the Server exposing address, creating it and starting its
// accept loop on first use. Servers are keyed by port only, so a request for
// an already-used port with a different host part is rejected. It returns nil
// if address is malformed, conflicts with an existing server, or cannot be
// listened on.
func getServer(address string) *Server {
	// net.SplitHostPort accepts "host:port", ":port" and bracketed IPv6 hosts
	// such as "[::]:9001"; a plain split on ":" rejects the IPv6 form.
	_, port, err := net.SplitHostPort(address)
	if err != nil {
		log.Println("address error", address)
		return nil
	}
	mutex.Lock()
	defer mutex.Unlock()
	if server := serverMap[port]; server != nil {
		if server.address != address {
			log.Println("端口重复", server.address, address)
			return nil
		}
		return server
	}
	listener, err := net.Listen("tcp", address)
	if err != nil {
		log.Println("listen error", err)
		return nil
	}
	server := &Server{address: address}
	serverMap[port] = server
	go server.accept(&listener, port)
	return server
}
// loadPrivateKey parses an unencrypted PEM-encoded RSA private key.
//
// PKCS#8 ("BEGIN PRIVATE KEY") is tried first. If that fails, PKCS#1
// ("BEGIN RSA PRIVATE KEY") is tried as a fallback. When both fail, the
// PKCS#8 error is returned. A PKCS#8 key of a non-RSA type is rejected.
// Encrypted PEM blocks are not supported.
func loadPrivateKey(privateKeyStr string) (privateKey *rsa.PrivateKey, err error) {
	block, _ := pem.Decode([]byte(privateKeyStr))
	if block == nil {
		return nil, fmt.Errorf("解码私钥失败")
	}
	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		if pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
			return pkcs1Key, nil
		}
		return nil, err
	}
	privateKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("非法私钥文件")
	}
	return privateKey, nil
}
// switchData copies bytes from src to dst until src reports EOF or a read or
// write fails.
//
// Each write to dst gets a fresh timeOut write deadline. After every forwarded
// chunk, the read deadlines of both src and dst are pushed timeOut into the
// future, so traffic in this direction also keeps the opposite direction's
// reader alive. It returns the number of bytes written to dst and the first
// non-EOF error; that error is also logged.
func switchData(dst net.Conn, src net.Conn) (written int64, err error) {
	chunk := make([]byte, 10240)
	extendReadDeadlines := func() {
		src.SetReadDeadline(time.Now().Add(timeOut))
		dst.SetReadDeadline(time.Now().Add(timeOut))
	}
	extendReadDeadlines()
	for {
		nr, readErr := src.Read(chunk)
		if nr > 0 {
			dst.SetWriteDeadline(time.Now().Add(timeOut))
			nw, writeErr := dst.Write(chunk[:nr])
			if nw < 0 || nw > nr {
				nw = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			written += int64(nw)
			if writeErr == nil && nw != nr {
				writeErr = io.ErrShortWrite
			}
			if writeErr != nil {
				err = writeErr
				break
			}
		}
		if readErr != nil {
			if readErr != io.EOF {
				err = readErr
			}
			break
		}
		if nr > 0 {
			extendReadDeadlines()
		}
	}
	if err != nil {
		log.Println(err)
	}
	return written, err
}
客户端:
package main
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"time"
)
// NAT client configuration.
var (
// publicKeyStr is the dispatcher's RSA public key (PKIX, PEM) used to encrypt
// the handshake header.
publicKeyStr = "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCWAX9+7stLFV8sW2zA470M8b/5\nHt1FgkpGIVHfHvjIxh3k/APVfWlXpoN6lKIDQ/z4LZc+m03faeR/qjgl562W0sHQ\nDezv/cd84Uc2hDh/vTifL6RfNA7mrW3aqiVxT4gzvp327nzck/J/mzfVFyEgFb+z\nWsvr0xMkg+NNXMww8wIDAQAB\n-----END PUBLIC KEY-----\n"
// natDispatcherAddr is the address of the dispatcher (server program).
natDispatcherAddr = "192.168.10.18:8989"
// natServerOpenAddr is the address the dispatcher is asked to listen on
// publicly for this client.
natServerOpenAddr = ":9001"
// localServerAddr is the local service that tunneled connections are
// forwarded to.
localServerAddr = "127.0.0.1:8080"
// natPassword must match one of the dispatcher's natPasswords.
natPassword = "yzh"
// maxFreeNat is the number of worker loops, each keeping one control
// connection to the dispatcher open.
maxFreeNat = 3
// rateInterval is the heartbeat period on an idle control connection.
rateInterval = time.Second * 5
// timeOut is the dial timeout and the read/write deadline for connection
// operations.
timeOut = time.Second * 15
)
// main starts the NAT client. It encrypts "<natServerOpenAddr>-<natPassword>"
// with the dispatcher's public key once, then runs maxFreeNat worker loops:
// workers 1..maxFreeNat-1 in background goroutines and worker 0 on the main
// goroutine.
func main() {
	log.Println("入参: " + strings.Join(os.Args[1:], " "))
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	log.Println("Nat分发服务地址: " + natDispatcherAddr)
	log.Println("Nat服务器开放地址: " + natServerOpenAddr)
	log.Println("本地服务地址: " + localServerAddr)

	publicKey, err := loadPublicKey(publicKeyStr)
	if err != nil {
		log.Panicln(err)
	}
	handshake := []byte(natServerOpenAddr + "-" + natPassword)
	natInfo, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, handshake, nil)
	if err != nil {
		log.Panicln(err)
	}
	for worker := 1; worker < maxFreeNat; worker++ {
		go start(natInfo, worker)
	}
	// start loops forever, so running worker 0 here keeps main alive.
	start(natInfo, 0)
}
// start is one worker loop. It calls startNat again every time the previous
// call returns, so one control connection to the dispatcher is always being
// established or kept idle. index only labels the worker in the log.
func start(natInfo []byte, index int) {
	workerID := index
	for {
		log.Println("startNatId", workerID)
		startNat(natInfo)
	}
}
// startNat opens one control connection to the dispatcher, sends the
// encrypted handshake natInfo and then waits to be selected for a tunnel.
//
// While waiting, a helper goroutine sends a 1-byte heartbeat every
// rateInterval. Bytes received from the dispatcher are interpreted as
// follows:
//   - 1: start. The connection is handed to startServer in a new goroutine
//     and startNat returns, so the caller can open a replacement.
//   - 2: handshake rejected. The process panics.
//   - anything else (heartbeat replies): ignored.
//
// Any dial, write or read error closes the connection and returns.
func startNat(natInfo []byte) {
	server, err := net.DialTimeout("tcp", natDispatcherAddr, timeOut)
	if err != nil {
		log.Println(err)
		time.Sleep(time.Second)
		return
	}
	server.SetWriteDeadline(time.Now().Add(timeOut))
	n, err := server.Write(natInfo)
	if err != nil {
		log.Println(err)
		server.Close()
		return
	}
	if n != len(natInfo) {
		log.Println("写入长度错误", n, len(natInfo))
		server.Close()
		return
	}
	mutex := &sync.Mutex{}
	nating := false // guarded by mutex; once true, heartbeats stop
	go func() {
		// This goroutine keeps its own write count and error. Sharing the
		// enclosing function's n/err would be a data race with the reader
		// loop below, which assigns and checks err concurrently.
		heartbeat := []byte{1}
		wn := 1
		var werr error
		for {
			time.Sleep(rateInterval)
			mutex.Lock()
			if nating {
				mutex.Unlock()
				return
			}
			// Refresh the deadline only after a complete 1-byte write.
			if wn == 1 {
				server.SetWriteDeadline(time.Now().Add(timeOut))
			}
			wn, werr = server.Write(heartbeat)
			mutex.Unlock()
			if werr != nil {
				log.Println(werr)
				return
			}
		}
	}()
	buf := make([]byte, 1)
	r := 1
	for {
		if r == 1 {
			if buf[0] == 1 {
				break
			}
			if buf[0] == 2 {
				log.Panicln("config error or port used", natServerOpenAddr)
			}
			server.SetReadDeadline(time.Now().Add(timeOut))
		}
		r, err = server.Read(buf)
		if err != nil {
			log.Println(err)
			server.Close()
			return
		}
	}
	// Stop heartbeats before startServer writes its start confirmation.
	mutex.Lock()
	nating = true
	mutex.Unlock()
	go startServer(&server)
}
// startServer runs one established tunnel on the NAT-client side.
//
// It confirms the start signal by writing byte 0 on the control connection
// client, dials localServerAddr, and then copies data in both directions. When
// the local service finishes first, both connections are closed right away.
// Without that, the peer behind the dispatcher would wait for the timeOut idle
// deadline before seeing the close. The server side closes both ends in the
// same way when either direction ends.
func startServer(client *net.Conn) {
	buf := make([]byte, 1)
	buf[0] = 0
	defer (*client).Close()
	(*client).SetWriteDeadline(time.Now().Add(timeOut))
	n, err := (*client).Write(buf)
	if err != nil {
		log.Println(err)
		return
	}
	if n != 1 {
		log.Println("写入开始命令错误")
		return
	}
	server, err := net.DialTimeout("tcp", localServerAddr, timeOut)
	if err != nil {
		log.Println(err)
		return
	}
	defer server.Close()
	go func() {
		switchData(*client, server)
		// Local side finished: close both ends so the reverse copy below (and
		// the remote peer) is released now. Double Close is harmless.
		(*client).Close()
		server.Close()
	}()
	switchData(server, *client)
}
// loadPublicKey parses a PEM-encoded PKIX ("BEGIN PUBLIC KEY") public key and
// returns it as an RSA key. It fails if the PEM cannot be decoded, the DER is
// invalid, or the key is not RSA.
func loadPublicKey(publicKeyStr string) (publicKey *rsa.PublicKey, err error) {
	block, _ := pem.Decode([]byte(publicKeyStr))
	if block == nil {
		return nil, fmt.Errorf("解码公钥失败")
	}
	parsed, parseErr := x509.ParsePKIXPublicKey(block.Bytes)
	if parseErr != nil {
		return nil, parseErr
	}
	if rsaKey, isRSA := parsed.(*rsa.PublicKey); isRSA {
		return rsaKey, nil
	}
	return nil, fmt.Errorf("非法公钥文件")
}
// switchData copies bytes from src to dst until src reports EOF or a read or
// write fails.
//
// Each write to dst gets a fresh timeOut write deadline. After every forwarded
// chunk, the read deadlines of both src and dst are pushed timeOut into the
// future, so traffic in this direction also keeps the opposite direction's
// reader alive. It returns the number of bytes written to dst and the first
// non-EOF error; that error is also logged.
func switchData(dst net.Conn, src net.Conn) (written int64, err error) {
	chunk := make([]byte, 10240)
	extendReadDeadlines := func() {
		src.SetReadDeadline(time.Now().Add(timeOut))
		dst.SetReadDeadline(time.Now().Add(timeOut))
	}
	extendReadDeadlines()
	for {
		nr, readErr := src.Read(chunk)
		if nr > 0 {
			dst.SetWriteDeadline(time.Now().Add(timeOut))
			nw, writeErr := dst.Write(chunk[:nr])
			if nw < 0 || nw > nr {
				nw = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			written += int64(nw)
			if writeErr == nil && nw != nr {
				writeErr = io.ErrShortWrite
			}
			if writeErr != nil {
				err = writeErr
				break
			}
		}
		if readErr != nil {
			if readErr != io.EOF {
				err = readErr
			}
			break
		}
		if nr > 0 {
			extendReadDeadlines()
		}
	}
	if err != nil {
		log.Println(err)
	}
	return written, err
}
由于未找到读取加密私钥的较好方案，服务端改用了未加密的私钥。
网友评论