Rust for cpp dev - Thread Pools

Author: 找不到工作 | Published 2021-05-26 19:38

    In the web server project we used only a single thread, but real-world servers use multiple threads or processes to improve concurrency. In this chapter, we optimize the web server with a thread pool.

    20.2 Optimizing with a thread pool

    The benefits of using a thread pool:

    • Exploit multicore processors by handling requests in parallel
    • Cap the maximum number of threads, mitigating DoS (Denial of Service) attacks
    • Avoid the overhead of repeatedly creating threads

    How a thread pool works

    A typical thread pool consists of two parts:

    1. A set of Worker threads
    2. A task queue

    Each Worker thread tries to take a task from the queue and run it, blocking while the queue is empty. The task queue therefore needs multiple consumers. As for producers, the web server scenario has only one: the thread that accepts connections.

    The thread pool should expose the following interface:

    • ThreadPool::new: on construction, spawn the requested number of Worker threads
    • ThreadPool::execute: push a task onto the queue
    • ThreadPool::drop: on destruction, guarantee that all queued tasks finish

    A simple thread pool in Rust

    In Rust, std::sync::mpsc provides a multiple-producer, single-consumer queue. To turn it into a multiple-consumer queue, we have to add our own synchronization on the consumer end.

    pub fn channel<T>() -> (Sender<T>, Receiver<T>)
    

    std::sync::mpsc::channel produces a Sender and a Receiver. The ThreadPool holds the queue's Sender and pushes tasks into it from execute, while all Workers share the queue's Receiver and repeatedly try to pull tasks from it.

    We define the "task" and the "shared Receiver" types as follows:

    type Job = Box<dyn FnOnce() + Send + 'static>;
    
    // all workers pull tasks from it
    type SharedTaskReceiver = Arc<Mutex<mpsc::Receiver<Job>>>;
    

    ThreadPool

    With this design, the ThreadPool implementation follows naturally:

    pub struct ThreadPool {
        workers: Vec<Worker>,
        task_sender: mpsc::Sender<Job>,
    }
    
    
    impl ThreadPool {
        pub fn new(size: usize) -> Result<ThreadPool, &'static str> {
            if size == 0 {
                return Err("A thread pool with 0 threads is not allowed");
            }
    
            let (sender, receiver) = mpsc::channel();
            let shared_receiver = Arc::new(Mutex::new(receiver));
    
            let mut workers = vec![];
            for i in 0..size {
                workers.push(Worker::new(i, Arc::clone(&shared_receiver)));
            }
    
            Ok(ThreadPool {
                workers,
                task_sender: sender,
            })
        }
    
        pub fn execute<F: FnOnce() + Send + 'static>(&self, func: F) {
            self.task_sender.send(Box::new(func)).unwrap();
        }
    }
    

    Note that we use the Arc<Mutex<Receiver>> type to express that the Receiver is "shared" among all Workers: Arc for shared ownership across threads, Mutex for exclusive access.

    Worker

    A Worker runs a loop in which it tries to pull the next task from the Receiver.

    struct Worker {
        id: usize,
        handle: std::thread::JoinHandle<()>, // task return type is empty '()'
    }
    
    impl Worker {
        pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
            let handle = std::thread::spawn(
                move || {
                    loop {
                        let task = task_receiver.lock().unwrap().recv();
                        if let Ok(f) = task {
                            println!("Worker {} got a job, executing...", id);
                            f();
                        }
                    }
                });
            Worker { id, handle }
        }
    }
    

    Note that both lock and recv are blocking. For example, if thread A acquires the lock first but there is no task, A blocks inside recv without releasing the lock, and every other thread then blocks in lock.

    Waiting for Workers to finish their tasks

    We implement the Drop trait for ThreadPool to join all the threads, guaranteeing that queued tasks complete.

    impl Drop for ThreadPool {
        fn drop(&mut self) {
            for worker in self.workers.iter_mut() {
                println!("Shutting down worker {}", worker.id);
                worker.handle.join().unwrap();
            }
        }
    }
    

    This naive implementation fails to compile:

    error[E0507]: cannot move out of `worker.handle` which is behind a mutable reference
    

    This is because join takes ownership of the JoinHandle, while drop receives &mut self, which is only a mutable reference.

    So how do we obtain ownership of something we only have a reference to?

    The answer is to change the JoinHandle owned by Worker into an Option<JoinHandle>. Then Option::take can move the JoinHandle out, while from the Worker's point of view it still owns an Option<JoinHandle>, whose contents have simply become None.

    This is a reasonable design, not a workaround: in the for loop during destruction, a Worker that has already been joined really no longer owns a JoinHandle, while the Workers not yet joined still do, so the field genuinely is an Option. In essence, Rust's ownership rules make the program more rigorous.

    struct Worker {
        id: usize,
        // task return type is empty '()'
        // after being joined in drop, a worker no longer owns its JoinHandle
        handle: Option<std::thread::JoinHandle<()>>,
    }
    

    In drop, we use take to replace the Option<JoinHandle> with None and take ownership of the handle:

    impl Drop for ThreadPool {
        fn drop(&mut self) {
            for worker in self.workers.iter_mut() {
                println!("Shutting down worker {}", worker.id);
                if let Some(handle) = worker.handle.take() {
                    handle.join().unwrap();
                }
            }
        }
    }
    

    With these changes the program compiles and runs, but instead of exiting it hangs, which is still not the behavior we want. Each Worker runs an infinite loop that never ends, so join waits forever.

    Stopping the Workers

    We now modify the program so the ThreadPool can signal the loops to stop. First, the queue no longer carries only Jobs; it may also carry a shutdown signal, Terminate:

    enum Message {
        Task(Job),
        Terminate,
    }
    

    Then, when a Worker receives the Terminate signal, it breaks out of its infinite loop:

    impl Worker {
        pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
            let handle = std::thread::spawn(move || loop {
                let message = task_receiver.lock().unwrap().recv().unwrap();
                match message {
                    Message::Task(job) => {
                        println!("Worker {} got a job, executing...", id);
                        job();
                    }
                    Message::Terminate => {
                        println!("Worker {} was told to terminate.", id);
                        break;
                    }
                }
            });
            Worker {
                id,
                handle: Some(handle),
            }
        }
    }
    

    So when do we send the Terminate signal? On destruction:

    impl Drop for ThreadPool {
        fn drop(&mut self) {
            println!(
                "Sending terminate message to {} workers",
                self.workers.len()
            );
            for _ in self.workers.iter() {
                self.task_sender.send(Message::Terminate).unwrap();
            }
    
            for worker in self.workers.iter_mut() {
                println!("Shutting down worker {}", worker.id);
                if let Some(handle) = worker.handle.take() {
                    handle.join().unwrap();
                }
            }
        }
    }
    

    We first send n Terminate messages for the n Workers. Since a Worker that has received Terminate stops processing messages, each Worker consumes exactly one Terminate.

    Also note that we use two separate for loops, one to send Terminate and one to join the threads, to avoid a deadlock. Consider the simple case of only 2 Workers, A and B. With a single loop that both sends Terminate and joins, the following can happen:

    • We send one Terminate; A is still running a job, so B receives it and exits
    • We try to join A, but A has not received a Terminate, so it keeps waiting
    • The for loop is stuck

    With that, a basic thread pool is complete. The full code is in the appendix.

    Appendix

    The complete code:

    // lib.rs
    use std::sync::{mpsc, Arc, Mutex};
    
    type Job = Box<dyn FnOnce() + Send + 'static>;
    
    enum Message {
        Task(Job),
        Terminate,
    }
    
    // all workers pull tasks from it
    type SharedTaskReceiver = Arc<Mutex<mpsc::Receiver<Message>>>;
    
    struct Worker {
        id: usize,
        // task return type is empty '()'
        // after being joined in drop, a worker no longer owns its JoinHandle
        handle: Option<std::thread::JoinHandle<()>>,
    }
    
    impl Worker {
        pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
            let handle = std::thread::spawn(move || loop {
                let message = task_receiver.lock().unwrap().recv().unwrap();
                match message {
                    Message::Task(job) => {
                        println!("Worker {} got a job, executing...", id);
                        job();
                    }
                    Message::Terminate => {
                        println!("Worker {} was told to terminate.", id);
                        break;
                    }
                }
            });
            Worker {
                id,
                handle: Some(handle),
            }
        }
    }
    
    pub struct ThreadPool {
        workers: Vec<Worker>,
        task_sender: mpsc::Sender<Message>,
    }
    
    impl ThreadPool {
        pub fn new(size: usize) -> Result<ThreadPool, &'static str> {
            if size == 0 {
                return Err("A thread pool with 0 threads is not allowed");
            }
    
            let (sender, receiver) = mpsc::channel();
            let shared_receiver = Arc::new(Mutex::new(receiver));
    
            let mut workers = vec![];
            for i in 0..size {
                workers.push(Worker::new(i, Arc::clone(&shared_receiver)));
            }
    
            Ok(ThreadPool {
                workers,
                task_sender: sender,
            })
        }
    
        pub fn execute<F: FnOnce() + Send + 'static>(&self, func: F) {
            let new_job = Message::Task(Box::new(func));
            self.task_sender.send(new_job).unwrap();
        }
    }
    
    impl Drop for ThreadPool {
        fn drop(&mut self) {
            println!(
                "Sending terminate message to {} workers",
                self.workers.len()
            );
            for _ in self.workers.iter() {
                self.task_sender.send(Message::Terminate).unwrap();
            }
    
            for worker in self.workers.iter_mut() {
                println!("Shutting down worker {}", worker.id);
                if let Some(handle) = worker.handle.take() {
                    handle.join().unwrap();
                }
            }
        }
    }
    
    // main.rs
    use std::fs;
    use std::io::prelude::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    
    fn handle_client(mut stream: TcpStream) {
        let mut buffer = [0; 1024];
        stream.read(&mut buffer).unwrap();
    
        // GET request prefix
        let get = b"GET / HTTP/1.1\r\n";
    
        let mut status_line = "HTTP/1.1 200 OK";
        let mut filename = "hello.html";
    
        if !buffer.starts_with(get) {
            status_line = "HTTP/1.1 404 NOT FOUND";
            filename = "404.html";
        }
    
        let contents = fs::read_to_string(filename).unwrap();
        let response = format!(
            "{}\r\nContent-Length: {}\r\n\r\n{}",
            status_line,
            contents.len(),
            contents
        );
        stream.write_all(response.as_bytes()).unwrap();
        stream.flush().unwrap();
    }
    
    fn main() {
        let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
        let pool = web_server::ThreadPool::new(4).unwrap();
    
        for stream in listener.incoming() {
            match stream {
                Ok(stream) => pool.execute(|| {
                    handle_client(stream);
                }),
                Err(e) => println!("connection failed: {}", e),
            }
        }
    }
    

    Open 127.0.0.1:7878 and keep refreshing; you can see different threads handling the requests:

    Worker 0 got a job, executing...
    Worker 1 got a job, executing...
    Worker 2 got a job, executing...
    Worker 3 got a job, executing...
    Worker 2 got a job, executing...
    Worker 0 got a job, executing...
    Worker 1 got a job, executing...
    Worker 3 got a job, executing...
    Worker 2 got a job, executing...
    Worker 0 got a job, executing...
    Worker 1 got a job, executing...
    Worker 3 got a job, executing...
    

    Unit test

    Here is a simple unit test checking two properties:

    1. The tasks run in parallel, which we infer from the test's running time
    2. All tasks complete, which we check with assert_eq!

    #[cfg(test)]
    mod tests {
        use std::sync::{Arc, Mutex};
        use std::time::Duration;
    
        #[test]
        fn slow_tasks_in_parallel() {
            let mut tasks = vec![];
    
            let counter = Arc::new(Mutex::new(0));
            let total = 200;
            for i in 0..total {
                let counter1 = Arc::clone(&counter);
                tasks.push(move || {
                    std::thread::sleep(Duration::from_millis(i));
                    let mut num = counter1.lock().unwrap();
                    *num += 1;
                });
            }
    
            let pool = crate::ThreadPool::new(10).unwrap();
    
            for task in tasks {
                pool.execute(task);
            }
    
            // guarantee to call ThreadPool::drop before check
            std::mem::drop(pool);
    
            assert_eq!(*counter.lock().unwrap(), total);
        }
    }
    

    Note that we manually call

    std::mem::drop(pool);
    

    This guarantees that all tasks finish before the check. A failing assert_eq! panics, and panics in multithreaded code are very hard to debug; the error is:

    running 1 test
    thread panicked while panicking. aborting.
    error: test failed, to rerun pass '--lib'
    
    Caused by:
      process didn't exit successfully: `~/Project/rust/web_server/target/debug/deps/web_server-0af50e188492a41d` (signal: 4, SIGILL: illegal instruction)
    

    It took quite a while to track down this bug.

    The final running time is 2.14s. Per the code, the 10 threads sleep for a combined 0+1+...+199 = 19900 ms = 19.9 s, so the tasks really did run in parallel.

    Original link: https://www.haomeiwen.com/subject/euahsltx.html