In the web server project we used only a single thread, but in practice servers use multiple threads or processes to improve concurrency. In this chapter we optimize the web server with a thread pool.
20.2 Optimizing with a Thread Pool
The benefits of a thread pool:
- Exploit multiple cores to handle requests in parallel
- Cap the maximum number of threads, mitigating DoS (Denial of Service) attacks
- Avoid the overhead of repeatedly creating threads
How a thread pool works
A typical thread pool consists of two parts:
- several Worker threads
- a task queue
Each Worker thread tries to pull a task from the queue and run it, blocking when the queue is empty. The task queue therefore needs multiple consumers; as for producers, in our web server there is only one, the thread that accepts connections.
A thread pool should expose three main operations:
- `ThreadPool::new`: on construction, spawn the requested number of Worker threads
- `ThreadPool::execute`: push a task onto the queue
- `ThreadPool::drop`: on destruction, ensure every queued task has finished
A minimal thread pool in Rust
In Rust, `std::sync::mpsc` provides a multiple-producer, single-consumer queue. To get a multiple-consumer queue, we have to add our own synchronization on the consumer end.

```rust
pub fn channel<T>() -> (Sender<T>, Receiver<T>)
```
`std::sync::mpsc::channel` returns a Sender and a Receiver. The ThreadPool holds the queue's Sender and pushes a task into it on each `execute` call, while all Workers jointly hold the Receiver and repeatedly try to pull tasks from it.
We define the "task" and the shared Receiver as follows:

```rust
type Job = Box<dyn FnOnce() + Send + 'static>;
// all workers get tasks from it
type SharedTaskReceiver = Arc<Mutex<mpsc::Receiver<Job>>>;
```
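Before building the pool, it may help to see this sharing pattern in isolation. The sketch below (names and values are made up for the demo) spawns two threads that drain one `mpsc::Receiver` through an `Arc<Mutex<...>>`; every message is consumed exactly once:

```rust
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel();
    // Wrap the single-consumer Receiver so several threads can share it.
    let shared_rx = Arc::new(Mutex::new(receiver));
    let received = Arc::new(Mutex::new(Vec::new()));

    let mut handles = vec![];
    for _ in 0..2 {
        let rx = Arc::clone(&shared_rx);
        let out = Arc::clone(&received);
        handles.push(thread::spawn(move || loop {
            // The MutexGuard is a temporary dropped at the end of this
            // statement, so the lock is released right after recv.
            let msg = rx.lock().unwrap().recv();
            match msg {
                Ok(n) => out.lock().unwrap().push(n),
                Err(_) => break, // Sender dropped: the channel is closed
            }
        }));
    }

    for n in 0..4 {
        sender.send(n).unwrap();
    }
    drop(sender); // close the channel so the workers' loops end
    for h in handles {
        h.join().unwrap();
    }

    let mut got = received.lock().unwrap().clone();
    got.sort();
    assert_eq!(got, vec![0, 1, 2, 3]); // each message consumed exactly once
    println!("ok");
}
```

Which thread receives which message is nondeterministic, but the `Mutex` guarantees that only one thread is inside `recv` at a time.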
ThreadPool
With this design, the ThreadPool implementation follows naturally:
```rust
pub struct ThreadPool {
    workers: Vec<Worker>,
    task_sender: mpsc::Sender<Job>,
}

impl ThreadPool {
    pub fn new(size: usize) -> Result<ThreadPool, &'static str> {
        if size == 0 {
            return Err("A thread pool with 0 threads is not allowed");
        }
        let (sender, receiver) = mpsc::channel();
        let shared_receiver = Arc::new(Mutex::new(receiver));
        let mut workers = vec![];
        for i in 0..size {
            workers.push(Worker::new(i, Arc::clone(&shared_receiver)));
        }
        Ok(ThreadPool {
            workers,
            task_sender: sender,
        })
    }

    pub fn execute<F: FnOnce() + Send + 'static>(&self, func: F) {
        self.task_sender.send(Box::new(func)).unwrap();
    }
}
```
Note in particular the `Arc<Mutex<Receiver>>` type, which expresses that the single Receiver is "shared" among all Workers.
Worker
The Worker tries to fetch tasks from the Receiver in a `loop`.
```rust
struct Worker {
    id: usize,
    handle: std::thread::JoinHandle<()>, // tasks return the unit type '()'
}

impl Worker {
    pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
        let handle = std::thread::spawn(move || loop {
            let task = task_receiver.lock().unwrap().recv();
            if let Ok(f) = task {
                println!("Worker {} got a job, executing...", id);
                f();
            }
        });
        Worker { id, handle }
    }
}
```
Note that both `lock` and `recv` block. For example, if thread A acquires the lock first but the queue is empty, A blocks inside `recv` without releasing the lock, and the other threads block inside `lock`.
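The flip side of this pattern is that the `MutexGuard` produced by `lock` is a temporary dropped at the end of the `let` statement, so the lock is released before the fetched task runs, and slow tasks can overlap. A small standalone demo (the 200 ms sleeps and the 350 ms bound are arbitrary values chosen for illustration):

```rust
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

fn main() {
    let (tx, rx) = mpsc::channel::<u64>();
    let rx = Arc::new(Mutex::new(rx));

    let mut handles = vec![];
    for _ in 0..2 {
        let rx = Arc::clone(&rx);
        handles.push(thread::spawn(move || loop {
            // Guard dropped at the end of this statement: the lock is
            // NOT held while the "job" (the sleep below) runs.
            let msg = rx.lock().unwrap().recv();
            match msg {
                Ok(ms) => thread::sleep(Duration::from_millis(ms)),
                Err(_) => break, // channel closed
            }
        }));
    }

    let start = Instant::now();
    tx.send(200).unwrap();
    tx.send(200).unwrap();
    drop(tx);
    for h in handles {
        h.join().unwrap();
    }
    // Two 200 ms jobs overlap across the two workers, so the total
    // time stays well under the 400 ms a serial run would need.
    assert!(start.elapsed() < Duration::from_millis(350));
    println!("elapsed: {:?}", start.elapsed());
}
```

Had we written `while let Ok(ms) = rx.lock().unwrap().recv() { ... }` instead, the guard would live for the whole loop body and the jobs would execute one at a time.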
Waiting for Workers to finish
We implement the `Drop` trait for ThreadPool to join all the threads, ensuring the queued tasks complete.
```rust
impl Drop for ThreadPool {
    fn drop(&mut self) {
        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            worker.handle.join().unwrap();
        }
    }
}
```
This naive implementation fails to compile:

```
error[E0507]: cannot move out of `worker.handle` which is behind a mutable reference
```

The reason is that `join` takes ownership of the `JoinHandle`, while `drop` receives `&mut self`, which is only a reference.
How, then, do we obtain ownership of something that sits behind a reference?

The answer is to change the `JoinHandle` owned by `Worker` into an `Option<JoinHandle>`. With `Option::take` we can move the `JoinHandle` out; from the `Worker`'s point of view it still owns an `Option<JoinHandle>`, only its contents have become `None`.

This is a reasonable model, not a workaround: in the destructor's for loop, a `Worker` that has already been joined no longer owns a `JoinHandle`, while the not-yet-joined ones still do, so the field genuinely is an `Option`. In essence, Rust's ownership rules make our program more rigorous.
```rust
struct Worker {
    id: usize,
    // tasks return the unit type '()';
    // during drop, workers that were joined first no longer own a JoinHandle
    handle: Option<std::thread::JoinHandle<()>>,
}
```
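As a standalone illustration (toy values, not the pool itself), `Option::take` moves the value out through a mutable reference and leaves `None` behind:

```rust
fn main() {
    let mut slot: Option<String> = Some(String::from("handle"));

    // take() moves the value out via a mutable reference,
    // leaving None behind -- exactly what Drop::drop needs.
    let owned = slot.take();
    assert_eq!(owned, Some(String::from("handle")));
    assert_eq!(slot, None);

    // A second take() finds nothing, so a double "join" is impossible.
    assert_eq!(slot.take(), None);
    println!("ok");
}
```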
In `drop`, use `take` to replace the `Option<JoinHandle>` with `None` and take ownership of the handle:
```rust
impl Drop for ThreadPool {
    fn drop(&mut self) {
        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            if let Some(handle) = worker.handle.take() {
                handle.join().unwrap();
            }
        }
    }
}
```
With these changes the program no longer exits immediately; instead it hangs, which is still not the behavior we want. Each Worker runs an endless `loop` that never finishes, so `join` waits forever.
Stopping the Workers
We will modify the program so the ThreadPool can deliver a stop signal that exits the `loop`. First, the queue now carries not only a `Job` but possibly a shutdown signal, `Terminate`:
```rust
enum Message {
    Task(Job),
    Terminate,
}
```
When a Worker receives the `Terminate` signal, it must break out of the infinite `loop`:
```rust
impl Worker {
    pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
        let handle = std::thread::spawn(move || loop {
            let message = task_receiver.lock().unwrap().recv().unwrap();
            match message {
                Message::Task(job) => {
                    println!("Worker {} got a job, executing...", id);
                    job();
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });
        Worker {
            id,
            handle: Some(handle),
        }
    }
}
```
So when should the `Terminate` signals be sent? On destruction.
```rust
impl Drop for ThreadPool {
    fn drop(&mut self) {
        println!(
            "Sending terminate message to {} workers",
            self.workers.len()
        );
        for _ in self.workers.iter() {
            self.task_sender.send(Message::Terminate).unwrap();
        }
        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            if let Some(handle) = worker.handle.take() {
                handle.join().unwrap();
            }
        }
    }
}
```
Here we first send n `Terminate` messages for n Workers. Since a Worker that has received `Terminate` stops reading the queue, each Worker consumes exactly one `Terminate` message.

Note also that we use two separate for loops, one to send `Terminate` and one to join the threads, to avoid a deadlock. Consider the simple case of just two Workers, A and B, with a single loop that both sends and joins:

- one `Terminate` is sent; A is still running a job, so B receives it and exits
- we try to join A, but A has not received a `Terminate` yet, so we wait forever
- the for loop is stuck
With that, a basic thread pool is complete. The full code is in the appendix.
Appendix
The complete code:
```rust
// lib.rs
use std::sync::{mpsc, Arc, Mutex};

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    Task(Job),
    Terminate,
}

// all workers get tasks from it
type SharedTaskReceiver = Arc<Mutex<mpsc::Receiver<Message>>>;

struct Worker {
    id: usize,
    // tasks return the unit type '()';
    // during drop, workers that were joined first no longer own a JoinHandle
    handle: Option<std::thread::JoinHandle<()>>,
}

impl Worker {
    pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
        let handle = std::thread::spawn(move || loop {
            let message = task_receiver.lock().unwrap().recv().unwrap();
            match message {
                Message::Task(job) => {
                    println!("Worker {} got a job, executing...", id);
                    job();
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });
        Worker {
            id,
            handle: Some(handle),
        }
    }
}

pub struct ThreadPool {
    workers: Vec<Worker>,
    task_sender: mpsc::Sender<Message>,
}

impl ThreadPool {
    pub fn new(size: usize) -> Result<ThreadPool, &'static str> {
        if size == 0 {
            return Err("A thread pool with 0 threads is not allowed");
        }
        let (sender, receiver) = mpsc::channel();
        let shared_receiver = Arc::new(Mutex::new(receiver));
        let mut workers = vec![];
        for i in 0..size {
            workers.push(Worker::new(i, Arc::clone(&shared_receiver)));
        }
        Ok(ThreadPool {
            workers,
            task_sender: sender,
        })
    }

    pub fn execute<F: FnOnce() + Send + 'static>(&self, func: F) {
        let new_job = Message::Task(Box::new(func));
        self.task_sender.send(new_job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        println!(
            "Sending terminate message to {} workers",
            self.workers.len()
        );
        for _ in self.workers.iter() {
            self.task_sender.send(Message::Terminate).unwrap();
        }
        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            if let Some(handle) = worker.handle.take() {
                handle.join().unwrap();
            }
        }
    }
}
```
```rust
// main.rs
use std::fs;
use std::io::prelude::{Read, Write};
use std::net::{TcpListener, TcpStream};

fn handle_client(mut stream: TcpStream) {
    let mut buffer = [0; 1024];
    stream.read(&mut buffer).unwrap();

    // GET request prefix
    let get = b"GET / HTTP/1.1\r\n";
    let mut status_line = "HTTP/1.1 200 OK";
    let mut filename = "hello.html";
    if !buffer.starts_with(get) {
        status_line = "HTTP/1.1 404 NOT FOUND";
        filename = "404.html";
    }

    let contents = fs::read_to_string(filename).unwrap();
    let response = format!(
        "{}\r\nContent-Length: {}\r\n\r\n{}",
        status_line,
        contents.len(),
        contents
    );
    stream.write(response.as_bytes()).unwrap();
    stream.flush().unwrap();
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
    let pool = web_server::ThreadPool::new(4).unwrap();
    for stream in listener.incoming() {
        match stream {
            Ok(stream) => pool.execute(|| {
                handle_client(stream);
            }),
            Err(e) => println!("connection failed: {}", e),
        }
    }
}
```
Open 127.0.0.1:7878 and keep refreshing; different threads can be seen handling the requests:
```
Worker 0 got a job, executing...
Worker 1 got a job, executing...
Worker 2 got a job, executing...
Worker 3 got a job, executing...
Worker 2 got a job, executing...
Worker 0 got a job, executing...
Worker 1 got a job, executing...
Worker 3 got a job, executing...
Worker 2 got a job, executing...
Worker 0 got a job, executing...
Worker 1 got a job, executing...
Worker 3 got a job, executing...
```
Unit test
Here is a simple unit test that checks two properties:
- the tasks run in parallel, which we can judge from the test's running time
- every task completes, which we can check with `assert_eq!`
```rust
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    #[test]
    fn slow_tasks_in_parallel() {
        let mut tasks = vec![];
        let counter = Arc::new(Mutex::new(0));
        let total = 200;
        for i in 0..total {
            let counter1 = Arc::clone(&counter);
            tasks.push(move || {
                std::thread::sleep(Duration::from_millis(i));
                let mut num = counter1.lock().unwrap();
                *num += 1;
            });
        }
        let pool = crate::ThreadPool::new(10).unwrap();
        for task in tasks {
            pool.execute(task);
        }
        // guarantee ThreadPool::drop runs before the check
        std::mem::drop(pool);
        assert_eq!(*counter.lock().unwrap(), total);
    }
}
```
It is important that we manually call `std::mem::drop(pool)`: this guarantees all tasks have finished before the check. If `assert_eq!` fails it panics, and a panic in multithreaded code is very hard to debug; the error reads:

```
running 1 test
thread panicked while panicking. aborting.
error: test failed, to rerun pass '--lib'

Caused by:
  process didn't exit successfully: `~/Project/rust/web_server/target/debug/deps/web_server-0af50e188492a41d` (signal: 4, SIGILL: illegal instruction)
```

It took quite some time to track down this bug.
The final run took 2.14 s, while according to our code the 10 threads slept for a combined 0 + 1 + ... + 199 = 19900 ms = 19.9 s, which shows the tasks indeed ran in parallel.
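The arithmetic behind this claim can be checked mechanically; a quick sketch (the 10-worker "ideal" time is a lower bound that ignores scheduling overhead):

```rust
fn main() {
    // Total sleep across all 200 tasks: 0 + 1 + ... + 199 ms.
    let total_sleep_ms: u64 = (0..200).sum();
    assert_eq!(total_sleep_ms, 19_900);

    // Spread over 10 workers, a perfect schedule needs ~1.99 s;
    // the observed 2.14 s is close to that bound.
    let ideal_ms = total_sleep_ms / 10;
    assert_eq!(ideal_ms, 1_990);
    println!("ideal parallel time: {} ms", ideal_ms);
}
```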