美文网首页
go语言实现http代理

go语言实现http代理

作者: 今天i你好吗 | 来源:发表于2024-01-06 23:14 被阅读0次

参考资料: https://zh.mojotv.cn/tutorial/golang-interface-reader-writer

参考资料中的代码在代理 HTTP 请求时，会把请求行中的资源路径原样以绝对路径（如 http://host/path）转发给目标服务器，而部分服务器无法识别这种格式。

本文修复了这一问题，并做了其他一些优化。

package main

import (
    "bufio"
    "fmt"
    "io"
    "log"
    "net"
    "net/url"
    "strconv"
    "strings"
)

// main starts the proxy. It listens on TCP port 8989 (empty host, i.e. all
// local addresses) and serves every accepted connection on its own goroutine
// via handleClientRequest. It never returns normally: a failure to listen or
// to accept a connection aborts the process through log.Panic.
func main() {
	const listenAddr = ":8989"
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	log.Println("代理地址: " + listenAddr)

	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Panic(err)
	}
	for {
		conn, acceptErr := ln.Accept()
		if acceptErr != nil {
			log.Panic(acceptErr)
		}
		go handleClientRequest(conn)
	}
}

// handleClientRequest serves a single proxy client connection and closes it
// before returning.
//
// The first request head is read with decodeHeader, and it decides the
// upstream server:
//   - CONNECT host:port: dial host:port, answer "200 Connection established",
//     then relay raw bytes in both directions.
//   - any other method: the request target must be absolute
//     (http://host[:port]/path?query). Dial host (port 80 when none is given)
//     and forward this request and every later one on the connection. Each
//     target is rewritten to origin form (/path?query), because some servers
//     do not accept absolute-form targets.
//
// Upstream bytes are copied back to the client unchanged. When the client side
// is finished (EOF, an error, or a request for a different upstream), the
// upstream connection is half-closed. The server can then finish its responses
// and hang up, which ends this function.
func handleClientRequest(client net.Conn) {
	if client == nil {
		return
	}
	defer client.Close()

	clientReader := bufio.NewReader(client)

	method, requestAddress, protocol, headers, headerLines, err := decodeHeader(clientReader)
	if err != nil {
		// io.EOF just means the client connected and left without a request.
		if err != io.EOF {
			log.Println(err)
		}
		return
	}
	nFirstLine := method + " " + requestAddress + " " + protocol

	var serverAddress string
	if method == "CONNECT" {
		serverAddress = requestAddress
	} else {
		target, err := url.Parse(requestAddress)
		if err != nil {
			log.Println("url解析错误: " + nFirstLine)
			log.Println(err)
			return
		}
		if target.Host == "" {
			// An origin-form target ("/path") names no server. Without this
			// check the proxy would dial ":80" on its own machine.
			log.Println("url解析错误: " + nFirstLine)
			return
		}
		serverAddress = hostPortWithDefault(target)
	}

	log.Println(nFirstLine + " " + serverAddress + "\n")
	server, err := net.Dial("tcp", serverAddress)
	if err != nil {
		log.Println(err)
		return
	}
	defer server.Close()

	if method == "CONNECT" {
		if _, err := fmt.Fprint(client, "HTTP/1.1 200 Connection established\r\n\r\n"); err != nil {
			log.Println(err)
			return
		}
		go func() {
			// clientReader (not client) is used because it may already hold
			// bytes the client sent right after the CONNECT head.
			io.Copy(server, clientReader)
			closeWrite(server)
		}()
	} else {
		// The decoded request is passed by value: the goroutine never writes
		// this function's variables.
		go func() {
			forwardRequests(server, clientReader, serverAddress,
				method, requestAddress, protocol, headers, headerLines)
			closeWrite(server)
		}()
	}
	io.Copy(client, server)
}

// forwardRequests relays HTTP requests read from clientReader to server. The
// caller has already decoded the first request head and passes it in. Later
// heads on the same connection are read with decodeHeader.
//
// Each request target is rewritten to origin form, and the whole head is sent
// with a single Write. A body is copied according to Content-Length. If the
// body cannot be framed that way (see requestBodyLength), parsing stops and the
// rest of the client stream is relayed verbatim.
//
// It returns when the client stops sending, on any read or write error, or when
// a request targets an address other than serverAddress, which this upstream
// connection cannot serve.
func forwardRequests(server io.Writer, clientReader *bufio.Reader, serverAddress string,
	method, requestAddress, protocol string, headers map[string]string, headerLines []string) {
	for {
		requestPath, err := originFormTarget(requestAddress, serverAddress)
		if err != nil {
			log.Println(err)
			return
		}

		var head strings.Builder
		head.WriteString(method + " " + requestPath + " " + protocol + "\r\n")
		for _, line := range headerLines {
			head.WriteString(line)
		}
		head.WriteString("\r\n")
		if _, err := io.WriteString(server, head.String()); err != nil {
			log.Println(err)
			return
		}

		bodyLength, framed := requestBodyLength(headers)
		if !framed {
			// The end of this body, and so the start of the next request,
			// cannot be located here. Stop parsing and pipe everything through.
			io.Copy(server, clientReader)
			return
		}
		// No CRLF follows a message body in HTTP/1.x, so exactly bodyLength
		// bytes belong to this request.
		if bodyLength > 0 {
			if _, err := io.CopyN(server, clientReader, bodyLength); err != nil {
				log.Println(err)
				return
			}
		}

		method, requestAddress, protocol, headers, headerLines, err = decodeHeader(clientReader)
		if err != nil {
			if err != io.EOF {
				log.Println(err)
			}
			return
		}
	}
}

// originFormTarget converts a request target to origin form (/path?query) for
// forwarding to the upstream at serverAddress.
//
// An absolute-form target must resolve to serverAddress, ignoring case and
// using port 80 when none is given. Origin-form targets ("/...") and "*" are
// passed through. Any other target is an error.
func originFormTarget(requestAddress, serverAddress string) (string, error) {
	u, err := url.Parse(requestAddress)
	if err != nil {
		return "", err
	}
	switch {
	case u.Host != "":
		if !strings.EqualFold(hostPortWithDefault(u), serverAddress) {
			return "", fmt.Errorf("request target %q does not match upstream %s", requestAddress, serverAddress)
		}
	case !strings.HasPrefix(requestAddress, "/") && requestAddress != "*":
		return "", fmt.Errorf("unsupported request target %q", requestAddress)
	}
	// RequestURI yields the escaped path plus query, and "/" for an empty path.
	return u.RequestURI(), nil
}

// requestBodyLength reports how many body bytes follow a request head whose
// lower-cased headers are given. framed is false when that number cannot be
// taken from Content-Length. This happens when a Transfer-Encoding other than
// "identity" is present, or when Content-Length is not a non-negative integer.
// A request without either header has no body (0, true).
func requestBodyLength(headers map[string]string) (length int64, framed bool) {
	if te, ok := headers["transfer-encoding"]; ok && !strings.EqualFold(te, "identity") {
		return 0, false
	}
	cl, ok := headers["content-length"]
	if !ok {
		return 0, true
	}
	n, err := strconv.ParseInt(cl, 10, 64)
	if err != nil || n < 0 {
		return 0, false
	}
	return n, true
}

// hostPortWithDefault returns u's host as a dialable "host:port" address. It
// uses port 80 when u has none and brackets IPv6 literals correctly.
func hostPortWithDefault(u *url.URL) string {
	if u.Port() == "" {
		return net.JoinHostPort(u.Hostname(), "80")
	}
	return u.Host
}

// closeWrite shuts down the write side of c when c supports half-close (as
// *net.TCPConn does). The peer then reads EOF while c can still be read.
func closeWrite(c net.Conn) {
	if hc, ok := c.(interface{ CloseWrite() error }); ok {
		// Best effort: the connection may already be fully closed.
		_ = hc.CloseWrite()
	}
}

// decodeHeader reads one HTTP request head from render: the request line plus
// the header lines, up to and including the empty line that ends the head.
// render is left positioned at the first byte after that empty line, e.g. the
// start of a request body.
//
// Return values, in order:
//   - method, request target and protocol version from the request line;
//   - headers: lower-cased header name -> value, with surrounding spaces, tabs
//     and CR/LF trimmed (a repeated header keeps its last value);
//   - headerLines: the raw header lines, terminators included, in arrival
//     order, ready to be re-sent. A "Proxy-Connection" line is re-emitted
//     under the name "Connection". Header lines whose name starts with
//     "proxy-" are also logged.
//   - err: a read error from render (io.EOF when the client closes between
//     requests), or an error for a request line that is not exactly three
//     whitespace-separated fields or a header line without ':'.
//
// Empty lines before the request line are skipped (RFC 7230 §3.5). Both CRLF
// and bare LF line endings are accepted.
func decodeHeader(render *bufio.Reader) (string, string, string, map[string]string, []string, error) {
	var method, requestAddress, protocol string
	var headers = map[string]string{}
	var headerLines = []string{}

	// Some clients send a stray CRLF after a request body. Ignore empty
	// lines until the request line arrives.
	var requestLine string
	for {
		lineData, err := render.ReadBytes('\n')
		if err != nil {
			return method, requestAddress, protocol, headers, headerLines, err
		}
		requestLine = string(lineData)
		if strings.TrimRight(requestLine, "\r\n") != "" {
			break
		}
	}

	// Request line: METHOD SP request-target SP HTTP-version.
	fields := strings.Fields(requestLine)
	if len(fields) != 3 {
		return method, requestAddress, protocol, headers, headerLines, fmt.Errorf("解析错误: %q", requestLine)
	}
	method, requestAddress, protocol = fields[0], fields[1], fields[2]

	for {
		lineData, err := render.ReadBytes('\n')
		if err != nil {
			return method, requestAddress, protocol, headers, headerLines, err
		}
		line := string(lineData)
		// An empty line (CRLF or bare LF) terminates the head.
		if strings.TrimRight(line, "\r\n") == "" {
			return method, requestAddress, protocol, headers, headerLines, nil
		}
		colon := strings.Index(line, ":")
		if colon < 0 {
			// Reject instead of slicing with -1, which would panic and take
			// down the whole process from inside a connection goroutine.
			return method, requestAddress, protocol, headers, headerLines, fmt.Errorf("解析错误: %q", line)
		}
		keyLower := strings.ToLower(strings.Trim(line[:colon], " \t\r\n"))
		value := line[colon+1:]
		if strings.HasPrefix(keyLower, "proxy-") {
			log.Println(line)
		}
		headers[keyLower] = strings.Trim(value, " \t\r\n")
		if keyLower == "proxy-connection" {
			headerLines = append(headerLines, "Connection:"+value)
		} else {
			headerLines = append(headerLines, line)
		}
	}
}


优化前后请求体对比:


image.png
编译好的产物:

https://download.csdn.net/download/qq_37873556/88739142
与本文中的代码略有差异，额外添加了超时机制等优化。
提供以下平台的编译产物，构建命令如下（注意：`go env -w` 会把设置永久写入 Go 的环境配置文件，构建完成后可执行 `go env -u GOOS GOARCH GOARM` 恢复默认值）：
go env -w GOOS=linux GOARCH=amd64
go install
go env -w GOOS=darwin GOARCH=amd64
go install
go env -w GOOS=darwin GOARCH=arm64
go install
go env -w GOOS=linux GOARCH=arm GOARM=5
go install
go env -w GOOS=windows GOARCH=amd64
go install

相关文章

网友评论

      本文标题:go语言实现http代理

      本文链接:https://www.haomeiwen.com/subject/moiwndtx.html