Go 语言版本: https://www.jianshu.com/p/56eebc0d7035
因为进行代理时需要在字节流和字符流之间来回切换，所以我使用了 okio 库: "com.squareup.okio:okio:1.17.2"
虽然使用 BufferedReader 也可以同时读取字符流和字节流，但读取字节流时每次只能逐个读取；即使 BufferedReader 自带缓存机制，也不确定这对性能的影响有多大（尚未进行过相关测试）。
![](https://img.haomeiwen.com/i8215544/077275be54a93359.png)
源码:
import okio.BufferedSink
import okio.BufferedSource
import okio.Okio
import java.net.*
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
/**
* 作者:yzh
*
* 创建时间:2024/1/6 23:39
*
* 描述: 基于 Okio 的简易 HTTP/HTTPS(CONNECT 隧道)代理服务器
*
* 修订历史:
*/
object HttpProxy {
    /** Header-name prefix that is rewritten to `Connection:` before forwarding. */
    private const val PROXY_CONNECTION = "Proxy-Connection:"

    /** Runs [handlerRequest] for every accepted client connection. */
    private val executor by lazy { Executors.newCachedThreadPool { Thread(it, "connect==") } }

    /** Pumps client -> server traffic (HTTP request heads/bodies, or the raw CONNECT tunnel). */
    private val executorA by lazy { Executors.newCachedThreadPool { Thread(it, "客户端==") } }

    /** Pumps server -> client traffic and closes both sockets when that stream ends. */
    private val executorB by lazy { Executors.newCachedThreadPool { Thread(it, "服务端==") } }

    /**
     * Listens on [localPort] (on all interfaces, or only on [localAddr] when it is given)
     * and hands every accepted connection to [handlerRequest] on a pooled thread.
     *
     * Blocks the calling thread forever. If `accept` throws, the listening socket is closed
     * and the exception propagates to the caller.
     */
    fun start(localPort: Int, localAddr: String? = null) {
        val serverSocket = if (localAddr.isNullOrBlank()) {
            ServerSocket(localPort)
        } else {
            ServerSocket(localPort, 50, InetAddress.getByName(localAddr))
        }
        serverSocket.use { listener ->
            while (true) {
                val accept = listener.accept()
                executor.execute { handlerRequest(accept) }
            }
        }
    }

    /**
     * Serves one client connection. It parses the first request head, connects to the
     * target server, and starts the two forwarding threads.
     *
     * - `CONNECT host[:port]` (port defaults to 443) is answered with
     *   `200 Connection established`, and raw bytes are then tunnelled both ways.
     * - Any other method must use an absolute-form target (`http://host[:port]/path`, port
     *   defaults to 80). Each request head is rewritten to origin-form (`/path?query`)
     *   before it is sent upstream.
     * - A `Content-Length` body is forwarded with its request. A `Transfer-Encoding` header
     *   or a negative length switches the rest of the connection to raw piping, because
     *   such a body cannot be delimited here.
     *
     * NOTE: every request on one client connection goes to the server chosen by the first
     * request, even if a later request names a different host.
     *
     * The client socket is closed on every early exit (unparsable request, bad target,
     * failed connect). Otherwise the forwarding threads own both sockets.
     */
    private fun handlerRequest(clientSocket: Socket) {
        val timeOut = 60
        // Okio's InputStream/OutputStream adapters only check deadlines between operations;
        // the timeout for a blocking read comes from soTimeout (as for the server socket).
        clientSocket.soTimeout = timeOut * 1000
        val clientSource = Okio.buffer(Okio.source(clientSocket.getInputStream()))
        val clientSink = Okio.buffer(Okio.sink(clientSocket.getOutputStream()))
        clientSource.timeout().timeout(timeOut.toLong(), TimeUnit.SECONDS)
        clientSink.timeout().timeout(timeOut.toLong(), TimeUnit.SECONDS)
        val httpsType = "CONNECT"
        var method = ""
        var requestAddress = ""
        var protocol = ""
        val headers = hashMapOf<String, String>()
        val headerLines = arrayListOf<String>()

        /**
         * Reads the next request head from the client and stores it in the shared state:
         * `method`, `requestAddress` and `protocol` from the request line, `headers`
         * (lower-cased names mapped to trimmed values), and `headerLines` (the lines to
         * forward).
         *
         * Returns false in these cases:
         * - the client socket is closed or its stream is exhausted;
         * - a read fails (including a timeout);
         * - the request line does not split into `METHOD TARGET VERSION`.
         */
        fun decodeHeader(): Boolean {
            if (clientSocket.isClosed) return false
            method = ""
            requestAddress = ""
            protocol = ""
            headers.clear()
            headerLines.clear()
            try {
                // Ignore empty lines before the request line (RFC 7230 §3.5), e.g. a stray
                // CRLF that some clients send after a request body.
                var line: String
                do {
                    line = clientSource.readUtf8Line() ?: return false
                } while (line.isBlank())
                val parts = line.split(" ")
                if (parts.size < 3) {
                    println("解析错误: $line")
                    return false
                }
                method = parts[0]
                requestAddress = parts[1]
                protocol = parts[2]
                val decodeLine = "$method $requestAddress $protocol"
                println(decodeLine)
                if (line != decodeLine) println("解析错误: $line")
                while (true) {
                    val headerLine = clientSource.readUtf8Line() ?: break
                    if (headerLine.isBlank()) break
                    if (headerLine.startsWith("Proxy-", true)) println(headerLine)
                    // Rewrite only the header name, never text that happens to appear in a value.
                    val forwarded = if (headerLine.startsWith(PROXY_CONNECTION, true)) {
                        "Connection:" + headerLine.substring(PROXY_CONNECTION.length)
                    } else {
                        headerLine
                    }
                    val index = forwarded.indexOf(':')
                    if (index >= 0) {
                        headers[forwarded.substring(0, index).trim().toLowerCase()] =
                            forwarded.substring(index + 1).trim()
                        headerLines.add(forwarded)
                    }
                }
            } catch (e: Exception) {
                return false
            }
            return true
        }

        if (!decodeHeader()) {
            closeQuietly(clientSocket)
            return
        }
        val isConnect = httpsType == method
        val serverSocketAddress = try {
            if (isConnect) {
                parseConnectAuthority(requestAddress)
            } else {
                URL(requestAddress).run { InetSocketAddress(host, if (port == -1) 80 else port) }
            }
        } catch (e: Exception) {
            closeQuietly(clientSocket)
            return
        }
        // val proxy = Proxy(Proxy.Type.SOCKS, InetSocketAddress("127.0.0.1", 1080))
        // val remoteSocket = Socket(proxy)
        val serverSocket = Socket()
        try {
            serverSocket.soTimeout = timeOut * 1000
            // connect(address) alone has no timeout of its own; bound it like the reads.
            serverSocket.connect(serverSocketAddress, timeOut * 1000)
        } catch (e: Exception) {
            closeQuietly(serverSocket)
            closeQuietly(clientSocket)
            return
        }
        val serverSink = Okio.buffer(Okio.sink(serverSocket.getOutputStream()))
        val serverSource = Okio.buffer(Okio.source(serverSocket.getInputStream()))
        serverSink.timeout().timeout(timeOut.toLong(), TimeUnit.SECONDS)
        serverSource.timeout().timeout(timeOut.toLong(), TimeUnit.SECONDS)

        if (isConnect) {
            try {
                clientSink.writeUtf8("HTTP/1.1 200 Connection established\r\n\r\n")
                clientSink.flush()
            } catch (e: Exception) {
                closeQuietly(serverSocket)
                closeQuietly(clientSocket)
                return
            }
            executorA.execute { switchData(serverSink, clientSource, clientSocket, serverSocket) }
        } else {
            executorA.execute {
                // The first request head was already parsed above; parse the rest here.
                var needDecodeHeader = false
                while (!clientSocket.isClosed && !serverSocket.isClosed) {
                    if (needDecodeHeader && !decodeHeader()) break
                    needDecodeHeader = true
                    val contentLength = headers["content-length"]?.toLongOrNull()
                    val pipeRest = headers.containsKey("transfer-encoding") ||
                        (contentLength != null && contentLength < 0)
                    try {
                        val requestPath = toOriginForm(requestAddress)
                        serverSink.writeUtf8("$method $requestPath $protocol\r\n")
                        for (headerLine in headerLines) {
                            serverSink.writeUtf8("$headerLine\r\n")
                        }
                        serverSink.writeUtf8("\r\n")
                        serverSink.flush()
                        if (pipeRest) {
                            switchData(serverSink, clientSource, clientSocket, serverSocket)
                            return@execute
                        }
                        if (contentLength != null && contentLength > 0) {
                            serverSink.write(clientSource, contentLength)
                            serverSink.flush()
                        }
                    } catch (e: Exception) {
                        break
                    }
                }
                // Pass the client's end-of-stream on to the server. A server that then closes
                // its side lets executorB see EOF and tear both sockets down immediately.
                try {
                    serverSocket.shutdownOutput()
                } catch (e: Exception) {
                }
            }
        }
        executorB.execute {
            switchData(clientSink, serverSource, clientSocket, serverSocket)
            closeQuietly(serverSocket)
            closeQuietly(clientSocket)
        }
    }

    /**
     * Copies bytes from [source] to [sink]. It stops when [source] is exhausted, a read or
     * write fails, or either socket has been closed locally, and then closes [source] and
     * [sink]. Closing a socket's stream also closes that socket (java.net.Socket contract).
     */
    private fun switchData(
        sink: BufferedSink,
        source: BufferedSource,
        clientSocket: Socket,
        serverSocket: Socket
    ) {
        val buffer = ByteArray(1024 * 2)
        while (!clientSocket.isClosed && !serverSocket.isClosed) {
            val length = try {
                source.read(buffer)
            } catch (e: Exception) {
                -1
            }
            if (length < 0) break
            try {
                sink.write(buffer, 0, length)
                sink.flush()
            } catch (e: Exception) {
                // The receiving side is gone; stop instead of reading and discarding the rest.
                break
            }
        }
        closeQuietly(source)
        closeQuietly(sink)
    }

    /**
     * Converts the absolute-form target that clients send to a proxy (`http://host/path?q`)
     * into the origin-form sent upstream (`/path?q`). An empty path becomes "/". Targets
     * already in origin-form are returned unchanged.
     *
     * @throws java.net.MalformedURLException if [address] is neither form.
     */
    private fun toOriginForm(address: String): String {
        if (address.startsWith("/")) return address
        val file = URL(address).file
        return if (file.startsWith("/")) file else "/$file"
    }

    /**
     * Parses a CONNECT target of the form `host[:port]` (an IPv6 literal as `[addr][:port]`).
     * A missing port defaults to 443.
     *
     * @throws NumberFormatException if the port is not a number.
     * @throws IllegalArgumentException if the port is out of range.
     */
    private fun parseConnectAuthority(authority: String): InetSocketAddress {
        val colon = authority.lastIndexOf(':')
        // A colon inside "[...]" belongs to an IPv6 literal, not to a port separator.
        return if (colon > authority.lastIndexOf(']')) {
            InetSocketAddress(authority.substring(0, colon), authority.substring(colon + 1).toInt())
        } else {
            InetSocketAddress(authority, 443)
        }
    }

    /** Closes [closeable] as best-effort teardown, ignoring any error. */
    private fun closeQuietly(closeable: AutoCloseable) {
        try {
            closeable.close()
        } catch (e: Exception) {
        }
    }
}
网友评论