Building a Thread Pool Based on Work-Stealing Queues in C++11

Author: FredricZhu | Published 2023-11-11 20:36

    This example is the grand finale of C++ Concurrency in Action; the remaining examples in the book are all about thread interruption, so I did not pay much attention to them.

    A simple thread pool is really just a thread-safe queue holding a number of tasks, plus several threads that keep pulling tasks off that queue and executing them.
    A thread pool built on work-stealing queues is slightly different: every worker thread owns a thread_local work-stealing queue, and all of these queues are also kept in a vector. When a thread looks for work, the priority is: first take a task from its own thread_local queue; if that is empty, take one from the global thread-safe queue; and if that also fails, steal a task from another thread's thread_local queue (note that stealing takes the task at the tail of that queue, i.e. the oldest task still sitting in its backlog). The benefit is that idle threads get more chances to run work, which improves throughput.
    The code is as follows.
    conanfile.txt

    [requires]
    boost/1.72.0
    
    [generators]
    cmake
    

    CMakeLists.txt

    cmake_minimum_required(VERSION 3.3)
    
    project(9_5_thread_pool_sorter)
    
    set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
    
    set ( CMAKE_CXX_FLAGS "-pthread")
    set(CMAKE_CXX_STANDARD 17)
    add_definitions(-g)
    
    include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
    conan_basic_setup()
    
    include_directories(${INCLUDE_DIRS})
    LINK_DIRECTORIES(${LINK_DIRS})
    
    file( GLOB main_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 
    
    foreach( main_file ${main_file_list} )
        file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${main_file})
        string(REPLACE ".cpp" "" file ${filename})
        add_executable(${file}  ${main_file})
        target_link_libraries(${file} ${CONAN_LIBS} pthread)
    endforeach( main_file ${main_file_list})
    

    function_wrapper.hpp

    #ifndef _FREDRIC_FUNC_WRAPPER_HPP_
    #define _FREDRIC_FUNC_WRAPPER_HPP_
    
    #include <memory>
    
    class function_wrapper {
        struct impl_base {
            virtual void call() = 0;
            virtual ~impl_base() {}
        };
    
        template <typename F>
        struct impl_type: impl_base {
            F f;
            impl_type(F&& f_): f(std::move(f_)) {}
    
            void call() {
                f();
            }
        };
        
        std::unique_ptr<impl_base> impl;
    
    public:
        function_wrapper() {}
    
        // This wrapper wraps a packaged_task, which is move-only and cannot be stored in a std::function
        template <typename F>
        function_wrapper(F&& f):
            impl(new impl_type<F>(std::move(f))) {}
    
        void call() {
            impl->call();
        }
    
        function_wrapper(function_wrapper&& other): impl(std::move(other.impl)) {}
    
        function_wrapper& operator=(function_wrapper&& other) {
            impl = std::move(other.impl);
            return *this;
        }
    
        function_wrapper(function_wrapper const&) = delete;
        function_wrapper(function_wrapper&) = delete;
        function_wrapper& operator=(function_wrapper const&) = delete;
    };
    #endif
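
    As a quick aside (this snippet is not part of the original listing, just a hypothetical standalone demo), the whole point of this wrapper is that std::packaged_task is move-only and therefore cannot be stored in a std::function; function_wrapper type-erases it instead:

    #include "function_wrapper.hpp"
    #include <future>
    #include <iostream>
    
    int main() {
        // packaged_task is move-only, so std::function cannot hold it,
        // but function_wrapper can.
        std::packaged_task<int()> task([] { return 42; });
        std::future<int> fut = task.get_future();
    
        function_wrapper wrapped(std::move(task));  // type-erase the move-only task
        wrapped.call();                             // runs the packaged_task
    
        std::cout << fut.get() << std::endl;        // prints 42
        return 0;
    }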
    

    work_stealing_queue.hpp

    #ifndef _FREDRIC_WORK_STEAL_QUEUE_HPP_
    #define _FREDRIC_WORK_STEAL_QUEUE_HPP_
    #include "function_wrapper.hpp"
    #include <deque>
    #include <mutex>
    
    
    // A simple thread-safe wrapper around a double-ended queue.
    // try_pop takes from the front, i.e. the most recently pushed task;
    // try_steal takes from the back, i.e. the oldest task in the backlog.
    class work_stealing_queue {
    private:
        typedef function_wrapper data_type;
        std::deque<data_type> the_queue;
        mutable std::mutex the_mutex;
    
    public:
        work_stealing_queue() {}
        work_stealing_queue(work_stealing_queue const&) = delete;
        work_stealing_queue& operator=(work_stealing_queue const&) = delete;
    
        void push(data_type data) {
            std::lock_guard<std::mutex> lock(the_mutex);
            the_queue.push_front(std::move(data));
        }
    
        bool empty() const {
            std::lock_guard<std::mutex> lock(the_mutex);
            return the_queue.empty();
        }
    
        bool try_pop(data_type& res) {
            std::lock_guard<std::mutex> lock(the_mutex);
            if(the_queue.empty()) {
                return false;
            }
    
            res = std::move(the_queue.front());
            the_queue.pop_front();
            return true;
        }
    
        bool try_steal(data_type& res) {
            std::lock_guard<std::mutex> lock(the_mutex);
            if(the_queue.empty()) {
                return false;
            }
            res = std::move(the_queue.back());
            the_queue.pop_back();
            return true;
        }
    
    
    };
    #endif
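
    Just as a hypothetical sanity check (not in the original post), this is the ordering the queue gives: the owning thread pops in LIFO order, while a thief steals the oldest entry from the opposite end:

    #include "work_stealing_queue.hpp"
    #include <cassert>
    
    int main() {
        work_stealing_queue q;
        int last_run = 0;
        q.push([&] { last_run = 1; });   // older task, ends up towards the back
        q.push([&] { last_run = 2; });   // newer task, sits at the front
    
        function_wrapper t;
        assert(q.try_pop(t));            // owner side: newest task first (LIFO)
        t.call();
        assert(last_run == 2);
    
        assert(q.try_steal(t));          // thief side: oldest task, taken from the back
        t.call();
        assert(last_run == 1);
        return 0;
    }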
    

    threadsafe_queue.hpp

    #ifndef _FREDRIC_THREAD_SAFE_QUEUE_HPP_
    #define _FREDRIC_THREAD_SAFE_QUEUE_HPP_
    
    #include <mutex>
    #include <string>
    #include <queue>
    #include <memory>
    #include <atomic>
    #include <condition_variable>
    #include <exception>
    
    template <typename T>
    class threadsafe_queue {
    private:
        struct node {
            std::shared_ptr<T> data;
            std::unique_ptr<node> next;
        };
    
        std::mutex head_mutex;
        std::mutex tail_mutex;
    
        std::unique_ptr<node> head;
        node* tail;
    
        std::condition_variable data_cond;
        
        node* get_tail() {
            std::lock_guard<std::mutex> tail_lock(tail_mutex);
            return tail;
        }
    
        std::unique_ptr<node> pop_head() {
            std::unique_ptr<node> old_head = std::move(head);
            head = std::move(old_head->next);
            return old_head;
        }
    
        std::unique_lock<std::mutex> wait_for_data() {
            std::unique_lock<std::mutex> head_lock(head_mutex);
            data_cond.wait(head_lock, [&]() {
                return head.get() != get_tail();
            });
    
            return std::move(head_lock);
        }
    
        std::unique_ptr<node> wait_pop_head() {
            std::unique_lock<std::mutex> head_lock(wait_for_data());
            return pop_head();
        }
    
        std::unique_ptr<node> wait_pop_head(T& value) {
            std::unique_lock<std::mutex> head_lock(wait_for_data());
            value = std::move(*head->data);
            return pop_head();
        }
    
        std::unique_ptr<node> try_pop_head() {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            if(head.get() == get_tail()) {
                return std::unique_ptr<node>();
            }
            return pop_head();
        }
    
        std::unique_ptr<node> try_pop_head(T& value) {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            if(head.get() == get_tail()) {
                return std::unique_ptr<node>();
            }
            value = std::move(*head->data);
            return pop_head();
        } 
    
    public:
        threadsafe_queue():
            head(new node), tail(head.get()) {}
        
        threadsafe_queue(threadsafe_queue const&) = delete;
        threadsafe_queue& operator=(threadsafe_queue const&) = delete;
    
        void push(T new_value) {
            std::shared_ptr<T> new_data(std::make_shared<T>(std::move(new_value)));
            std::unique_ptr<node> p (new node);
            {
                std::lock_guard<std::mutex> tail_lock(tail_mutex);
                tail->data = new_data;
                node* const new_tail = p.get();
                tail->next = std::move(p);
                tail = new_tail;
            }
    
            data_cond.notify_one();
        }
    
        std::shared_ptr<T> wait_and_pop() {
            std::unique_ptr<node> const old_head = wait_pop_head();
            return old_head->data;
        }
    
        void wait_and_pop(T& value) {
            wait_pop_head(value);
        }
    
        bool empty() {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            return (head.get() == get_tail());
        }
    
        std::shared_ptr<T> try_pop() {
            std::unique_ptr<node> old_head = try_pop_head();
            return old_head ? old_head->data: std::shared_ptr<T>();
        }
    
        bool try_pop(T& value) {
            std::unique_ptr<node> old_head = try_pop_head(value);
            return old_head != nullptr;
        }
    };
    
    #endif
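
    Again only a hypothetical standalone sketch (not part of the original code), showing how this queue behaves on its own: a producer pushes values while a consumer blocks in wait_and_pop until data arrives:

    #include "threadsafe_queue.hpp"
    #include <iostream>
    #include <thread>
    
    int main() {
        threadsafe_queue<int> q;
    
        std::thread producer([&] {
            for(int i = 0; i < 5; ++i) {
                q.push(i);              // notifies data_cond, waking the consumer
            }
        });
    
        std::thread consumer([&] {
            for(int i = 0; i < 5; ++i) {
                int value = 0;
                q.wait_and_pop(value);  // blocks until an item is available
                std::cout << value << " ";
            }
            std::cout << std::endl;
        });
    
        producer.join();
        consumer.join();
        return 0;
    }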
    

    thread_pool.hpp

    #ifndef _FREDRIC_THREAD_POOL_HPP_
    #define _FREDRIC_THREAD_POOL_HPP_
    #include "function_wrapper.hpp"
    #include "thread_safe_queue.hpp"
    #include "work_stealing_queue.hpp"
    #include <thread>
    #include <vector>
    #include <atomic>
    #include <functional>
    #include <utility>
    #include <future>
    #include <queue>
    #include <utility>
    
    #include <functional>
    #include <memory>
    
    // Joins all owned threads in its destructor (RAII), so the pool's worker
    // threads are always joined, even if an exception propagates.
    struct join_threads {
    
        std::thread& operator[](int index) {
            return threads[index];
        }
    
        void add_thread(std::thread&& thread) {
            threads.emplace_back(std::move(thread));
        }
    
        ~join_threads() {
            for(std::thread& thread: threads) {
                if(thread.joinable()) {
                    thread.join();
                }
            }
        }
    private:
        std::vector<std::thread> threads;
    };
    
    
    class thread_pool {
        typedef function_wrapper task_type;
    
        std::atomic<bool> done;
        threadsafe_queue<task_type> pool_work_queue;
        std::vector<std::unique_ptr<work_stealing_queue>> queues;
        join_threads joiner;
    
        static thread_local work_stealing_queue* local_work_queue;
        static thread_local unsigned my_index;
        static unsigned thread_count;
    
        void work_thread(unsigned my_index_) {
            my_index = my_index_;
            // Look up this thread's local work-stealing queue in the queues vector by index
            local_work_queue = queues[my_index].get();
            while(!done) {
                run_pending_task();
            }
        }
    
        bool pop_task_from_local_queue(task_type& task) {
            return local_work_queue && local_work_queue->try_pop(task);
        }
    
        bool pop_task_from_pool_queue(task_type& task) {
            return pool_work_queue.try_pop(task);
        }
    
        bool pop_task_from_other_thread_queue(task_type& task) {
            for(unsigned i=0; i<queues.size(); ++i) {
                // Start stealing from the neighbour after this thread so all
                // threads do not hammer the same victim queue first. The queues
                // vector is fully built before any worker thread starts (see the
                // constructor), so indexing it here is safe.
                unsigned const index = (my_index + i + 1) % queues.size();
                if(queues[index]->try_steal(task)) {
                    return true;
                }
            }

            return false;
        }
    
    public:
        thread_pool():
            done(false) {
            thread_count = std::thread::hardware_concurrency();

            try {
                // First create one work_stealing_queue per thread, so the queues
                // vector is complete before any worker thread can touch it ...
                for(unsigned i=0; i<thread_count; ++i) {
                    queues.push_back(std::unique_ptr<work_stealing_queue>(new work_stealing_queue));
                }
                // ... and only then start the worker threads.
                for(unsigned i=0; i<thread_count; ++i) {
                    joiner.add_thread(std::thread(&thread_pool::work_thread, this, i));
                }
            } catch(...) {
                done = true;
                throw;
            }
        }
    
        ~thread_pool() {
            done = true;
        }
    
        template <typename FunctionType>
        std::future<typename std::result_of<FunctionType()>::type> submit(FunctionType f) {
            typedef typename std::result_of<FunctionType()>::type result_type;
            std::packaged_task<result_type()> task(std::move(f));
            std::future<result_type> res = task.get_future();
    
            if(local_work_queue) {
                local_work_queue->push(std::move(task));
            } else {
                pool_work_queue.push(std::move(task));
            }
            return res;
        }
    
        void run_pending_task() {
            task_type task;
            // Priority: first this thread's local_work_queue (LIFO: the task it
            // pushed most recently), then the global pool_work_queue (FIFO),
            // and only then steal the oldest task from the tail of another
            // thread's queue.
            if(pop_task_from_local_queue(task) ||
                pop_task_from_pool_queue(task) ||
                pop_task_from_other_thread_queue(task)) {
                task.call();
            }else {
                std::this_thread::yield();
            }
        }
    };
    
    thread_local work_stealing_queue* thread_pool::local_work_queue {nullptr};
    thread_local unsigned thread_pool::my_index = 0;
    unsigned thread_pool::thread_count = 0;
    #endif
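
    Before the quicksort example below, here is a smaller, hypothetical usage sketch of submit (not from the original post). The tasks are submitted from the main thread, which has no local queue, so they all go through the global pool_work_queue; the results come back through futures:

    #include "thread_pool.hpp"
    #include <future>
    #include <iostream>
    #include <vector>
    
    int main() {
        thread_pool pool;
        std::vector<std::future<int>> results;
    
        for(int i = 0; i < 8; ++i) {
            // Each submitted callable must take no arguments; capture i by value.
            results.push_back(pool.submit([i] { return i * i; }));
        }
        for(auto& f: results) {
            std::cout << f.get() << " ";   // 0 1 4 9 16 25 36 49
        }
        std::cout << std::endl;
        return 0;
    }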
    

    main.cpp

    #include "thread_pool.hpp"
    #include <iostream>
    #include <algorithm>
    #include <numeric>
    #include <list>
    #include <chrono>
    
    template <typename T>
    struct  sorter {
        thread_pool pool;
    
        std::list<T> do_sort(std::list<T>& chunk_data) {
            if(chunk_data.empty()) {
                return chunk_data;
            }
    
            std::list<T> result;
            result.splice(result.begin(), chunk_data, chunk_data.begin());
            T const& partition_val = *result.begin();
            typename std::list<T>::iterator divide_point = std::partition(
                chunk_data.begin(), chunk_data.end(),
                [&](T const& val) {
                    return val < partition_val;
                }
            );
    
            std::list<T> new_lower_chunk;
            new_lower_chunk.splice(new_lower_chunk.end(), 
                chunk_data, chunk_data.begin(), divide_point);
            
            std::future<std::list<T>> new_lower = pool.submit([this,&new_lower_chunk]() {
                return do_sort(new_lower_chunk);
            }); 
    
            std::list<T> new_higher(do_sort(chunk_data));
            result.splice(result.end(), new_higher);
    
            while(true) {
                std::future_status status = new_lower.wait_for(std::chrono::seconds(0));
                if(status == std::future_status::ready) {
                    break;
                }
                pool.run_pending_task();
            }
            result.splice(result.begin(), new_lower.get());
            return result;
        }
    };
    
    
    template <typename T>
    std::list<T> parallel_quick_sort(std::list<T> input) {
        if(input.empty()) {
            return input;
        }
    
        sorter<T> s;
        return s.do_sort(input);
    }
    
    
    int main(int argc, char* argv[]) {
        std::list<int> ls;
    
        for(std::size_t i=0; i<10000; ++i) {
            if(i < 5000) {
                ls.push_back(i + 5000);
            } else {
                ls.push_back(i-5000);
            }
        }
    
    
        std::for_each(ls.begin(), ls.end(), [](int const& ele){
            std::cout << ele << " ";
        });
        std::cout << std::endl;
    
        std::list<int> res = parallel_quick_sort(ls);
    
        std::cout << "After parallel_quick_sort, result: " << std::endl;
        std::for_each(res.begin(), res.end(), [](int const& ele){
            std::cout << ele << " ";
        });
        std::cout << std::endl;
        return EXIT_SUCCESS;
    }
    

    The program output is shown below. Note that the list being sorted must not be too long; that limitation has nothing to do with the thread_pool itself, but comes from the recursion inside the sorter functor: a recursion stack that is too deep can overflow.


    [screenshot of the program output]
