This example is the grand finale of *C++ Concurrency in Action*; the remaining examples in the book deal with thread interruption, which I have not covered here.
A simple thread pool is essentially a thread-safe queue holding pending tasks, plus a fixed set of worker threads that repeatedly take tasks from that queue and execute them; a minimal sketch of this design follows.
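This sketch is my own illustration, not code from the book; the names simple_pool and worker are hypothetical. It shows the essence of the simple design: a single mutex-guarded task queue drained by N workers.

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class simple_pool {  // hypothetical name, for illustration only
    std::queue<std::function<void()>> tasks;
    std::mutex m;
    std::condition_variable cv;
    bool done = false;
    std::vector<std::thread> workers;
    void worker() {
        for (;;) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lk(m);
                cv.wait(lk, [&] { return done || !tasks.empty(); });
                if (done && tasks.empty()) return;
                task = std::move(tasks.front());
                tasks.pop();
            }
            task();  // run the task outside the lock
        }
    }
public:
    explicit simple_pool(unsigned n = std::thread::hardware_concurrency()) {
        for (unsigned i = 0; i < n; ++i) {
            workers.emplace_back(&simple_pool::worker, this);
        }
    }
    void submit(std::function<void()> f) {
        {
            std::lock_guard<std::mutex> lk(m);
            tasks.push(std::move(f));
        }
        cv.notify_one();
    }
    ~simple_pool() {
        {
            std::lock_guard<std::mutex> lk(m);
            done = true;
        }
        cv.notify_all();
        for (auto& t : workers) t.join();
    }
};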
A thread pool built on work-stealing queues differs in one respect: each worker thread owns a thread_local work-stealing queue, and all of these queues are also kept in a shared vector. When a worker needs a task, the priority is: first try its own thread_local queue; if that is empty, try the global thread-safe queue; and if that also fails, steal from another thread's work-stealing queue (note that stealing takes from the tail of the victim's queue, i.e. the oldest backlogged tasks). This increases the chance that every thread always has work to do, improving throughput.
The code follows.
conanfile.txt
[requires]
boost/1.72.0
[generators]
cmake
CMakeLists.txt
cmake_minimum_required(VERSION 3.3)
project(9_5_thread_pool_sorter)
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_FLAGS "-pthread")
set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()
include_directories(${INCLUDE_DIRS})
link_directories(${LINK_DIRS})
file( GLOB main_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
foreach( main_file ${main_file_list} )
file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${main_file})
string(REPLACE ".cpp" "" file ${filename})
add_executable(${file} ${main_file})
target_link_libraries(${file} ${CONAN_LIBS} pthread)
endforeach( main_file ${main_file_list})
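Assuming the classic Conan 1.x workflow that the cmake generator above implies, the project is typically built from a separate build directory: conan install .. first (to generate conanbuildinfo.cmake), then cmake .. and make. Conan 2.x removed this generator and would need a different setup.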
function_wrapper.hpp
#ifndef _FREDRIC_FUNC_WRAPPER_HPP_
#define _FREDRIC_FUNC_WRAPPER_HPP_
#include <memory>
class function_wrapper {
struct impl_base {
virtual void call() = 0;
virtual ~impl_base() {}
};
template <typename F>
struct impl_type: impl_base {
F f;
impl_type(F&& f_): f(std::move(f_)) {}
void call() override {
f();
}
};
std::unique_ptr<impl_base> impl;
public:
function_wrapper() {}
// This wrapper type-erases a move-only callable, e.g. a std::packaged_task
template <typename F>
function_wrapper(F&& f):
impl(new impl_type<F>(std::move(f))) {}
void call() {
impl->call();
}
function_wrapper(function_wrapper&& other): impl(std::move(other.impl)) {}
function_wrapper& operator=(function_wrapper&& other) {
impl = std::move(other.impl);
return *this;
}
function_wrapper(function_wrapper const&) = delete;
function_wrapper(function_wrapper&) = delete;
function_wrapper& operator=(function_wrapper const&) = delete;
};
#endif
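A quick usage sketch (my addition, not from the book): std::function requires a copyable target, so it cannot hold a std::packaged_task; function_wrapper can, because it only requires the callable to be movable.

#include "function_wrapper.hpp"
#include <future>
#include <iostream>

int main() {
    std::packaged_task<int()> pt([] { return 42; });
    std::future<int> f = pt.get_future();
    function_wrapper w(std::move(pt));  // std::function<void()> would not compile here
    w.call();                           // runs the task and makes the future ready
    std::cout << f.get() << std::endl;  // prints 42
    return 0;
}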
work_stealing_queue.hpp
#ifndef _FREDRIC_WORK_STEAL_QUEUE_HPP_
#define _FREDRIC_WORK_STEAL_QUEUE_HPP_
#include "function_wrapper.hpp"
#include <deque>
#include <mutex>
// A simple thread-safe wrapper around a double-ended queue.
// try_pop takes from the front: the most recently pushed (newest) task.
// try_steal takes from the back: the oldest, longest-waiting task.
class work_stealing_queue {
private:
typedef function_wrapper data_type;
std::deque<data_type> the_queue;
mutable std::mutex the_mutex;
public:
work_stealing_queue() {}
work_stealing_queue(work_stealing_queue const&) = delete;
work_stealing_queue& operator=(work_stealing_queue const&) = delete;
void push(data_type data) {
std::lock_guard<std::mutex> lock(the_mutex);
the_queue.push_front(std::move(data));
}
bool empty() const {
std::lock_guard<std::mutex> lock(the_mutex);
return the_queue.empty();
}
bool try_pop(data_type& res) {
std::lock_guard<std::mutex> lock(the_mutex);
if(the_queue.empty()) {
return false;
}
res = std::move(the_queue.front());
the_queue.pop_front();
return true;
}
bool try_steal(data_type& res) {
std::lock_guard<std::mutex> lock(the_mutex);
if(the_queue.empty()) {
return false;
}
res = std::move(the_queue.back());
the_queue.pop_back();
return true;
}
};
#endif
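A small sketch of the owner/thief asymmetry (my addition): the owning thread pops the newest task (LIFO, good for cache locality), while a thief steals the oldest one (FIFO), so the two ends rarely contend for the same task.

#include "work_stealing_queue.hpp"
#include <future>
#include <iostream>

int main() {
    work_stealing_queue q;
    for (int i = 0; i < 3; ++i) {
        std::packaged_task<void()> t([i] { std::cout << "task " << i << std::endl; });
        q.push(function_wrapper(std::move(t)));
    }
    function_wrapper fw;
    q.try_pop(fw);    // front of the deque: task 2, the newest
    fw.call();
    q.try_steal(fw);  // back of the deque: task 0, the oldest
    fw.call();
    return 0;
}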
threadsafe_queue.hpp
#ifndef _FREDRIC_THREAD_SAFE_QUEUE_HPP_
#define _FREDRIC_THREAD_SAFE_QUEUE_HPP_
#include <mutex>
#include <string>
#include <queue>
#include <memory>
#include <atomic>
#include <condition_variable>
#include <exception>
template <typename T>
class threadsafe_queue {
private:
struct node {
std::shared_ptr<T> data;
std::unique_ptr<node> next;
};
std::mutex head_mutex;
std::mutex tail_mutex;
std::unique_ptr<node> head;
node* tail;
std::condition_variable data_cond;
node* get_tail() {
std::lock_guard<std::mutex> tail_lock(tail_mutex);
return tail;
}
std::unique_ptr<node> pop_head() {
std::unique_ptr<node> old_head = std::move(head);
head = std::move(old_head->next);
return old_head;
}
std::unique_lock<std::mutex> wait_for_data() {
std::unique_lock<std::mutex> head_lock(head_mutex);
data_cond.wait(head_lock, [&]() {
return head.get() != get_tail();
});
return std::move(head_lock);
}
std::unique_ptr<node> wait_pop_head() {
std::unique_lock<std::mutex> head_lock(wait_for_data());
return pop_head();
}
std::unique_ptr<node> wait_pop_head(T& value) {
std::unique_lock<std::mutex> head_lock(wait_for_data());
value = std::move(*head->data);
return pop_head();
}
std::unique_ptr<node> try_pop_head() {
std::lock_guard<std::mutex> head_lock(head_mutex);
if(head.get() == get_tail()) {
return std::unique_ptr<node>();
}
return pop_head();
}
std::unique_ptr<node> try_pop_head(T& value) {
std::lock_guard<std::mutex> head_lock(head_mutex);
if(head.get() == get_tail()) {
return std::unique_ptr<node>();
}
value = std::move(*head->data);
return pop_head();
}
public:
threadsafe_queue():
head(new node), tail(head.get()) {}
threadsafe_queue(threadsafe_queue const&) = delete;
threadsafe_queue& operator=(threadsafe_queue const&) = delete;
void push(T new_value) {
std::shared_ptr<T> new_data(std::make_shared<T>(std::move(new_value)));
std::unique_ptr<node> p (new node);
{
std::lock_guard<std::mutex> tail_lock(tail_mutex);
tail->data = new_data;
node* const new_tail = p.get();
tail->next = std::move(p);
tail = new_tail;
}
data_cond.notify_one();
}
std::shared_ptr<T> wait_and_pop() {
std::unique_ptr<node> const old_head = wait_pop_head();
return old_head->data;
}
void wait_and_pop(T& value) {
wait_pop_head(value);
}
bool empty() {
std::lock_guard<std::mutex> head_lock(head_mutex);
return (head.get() == get_tail());
}
std::shared_ptr<T> try_pop() {
std::unique_ptr<node> old_head = try_pop_head();
return old_head ? old_head->data: std::shared_ptr<T>();
}
bool try_pop(T& value) {
std::unique_ptr<node> old_head = try_pop_head(value);
return old_head != nullptr;
}
};
#endif
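A brief usage sketch of this fine-grained queue (my addition): because push only locks tail_mutex and wait_and_pop mostly locks head_mutex, a producer and a consumer can often proceed concurrently.

#include "threadsafe_queue.hpp"
#include <iostream>
#include <thread>

int main() {
    threadsafe_queue<int> q;
    std::thread producer([&] {
        for (int i = 0; i < 5; ++i) q.push(i);  // takes tail_mutex only
    });
    std::thread consumer([&] {
        for (int i = 0; i < 5; ++i) {
            int v;
            q.wait_and_pop(v);  // blocks on the condition variable until data arrives
            std::cout << v << std::endl;
        }
    });
    producer.join();
    consumer.join();
    return 0;
}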
thread_pool.hpp
#ifndef _FREDRIC_THREAD_POOL_HPP_
#define _FREDRIC_THREAD_POOL_HPP_
#include "function_wrapper.hpp"
#include "thread_safe_queue.hpp"
#include "work_stealing_queue.hpp"
#include <thread>
#include <vector>
#include <atomic>
#include <functional>
#include <utility>
#include <future>
#include <queue>
#include <memory>
struct join_threads {
std::thread& operator[](int index) {
return threads[index];
}
void add_thread(std::thread&& thread) {
threads.emplace_back(std::move(thread));
}
~join_threads() {
for(std::thread& thread: threads) {
if(thread.joinable()) {
thread.join();
}
}
}
private:
std::vector<std::thread> threads;
};
class thread_pool {
typedef function_wrapper task_type;
std::atomic<bool> done;
threadsafe_queue<task_type> pool_work_queue;
std::vector<std::unique_ptr<work_stealing_queue>> queues;
join_threads joiner;
static thread_local work_stealing_queue* local_work_queue;
static thread_local unsigned my_index;
static unsigned thread_count;
void work_thread(unsigned my_index_) {
my_index = my_index_;
// Find this thread's local_work_queue in the shared vector by index
local_work_queue = queues[my_index].get();
while(!done) {
run_pending_task();
}
}
bool pop_task_from_local_queue(task_type& task) {
return local_work_queue && local_work_queue->try_pop(task);
}
bool pop_task_from_pool_queue(task_type& task) {
return pool_work_queue.try_pop(task);
}
bool pop_task_from_other_thread_queue(task_type& task) {
for(unsigned i=0; i<queues.size(); ++i) {
// Start from the next thread's queue so every thread does not
// hammer queues[0] first; the constructor now builds the whole
// vector before any worker starts, so indexing here is safe.
unsigned const index = (my_index + i + 1) % queues.size();
if(queues[index]->try_steal(task)) {
return true;
}
}
return false;
}
public:
thread_pool():
done(false) {
thread_count = std::thread::hardware_concurrency();
try {
// Create every per-thread work-stealing queue before starting any
// worker, so the vector is never resized while threads read it.
for(unsigned i=0; i<thread_count; ++i) {
queues.push_back(std::unique_ptr<work_stealing_queue>(new work_stealing_queue));
}
for(unsigned i=0; i<thread_count; ++i) {
joiner.add_thread(std::thread(&thread_pool::work_thread, this, i));
}
} catch(...) {
done = true;
throw;
}
}
~thread_pool() {
done = true;
}
template <typename FunctionType>
std::future<typename std::result_of<FunctionType()>::type> submit(FunctionType f) {
typedef typename std::result_of<FunctionType()>::type result_type;
std::packaged_task<result_type()> task(std::move(f));
std::future<result_type> res = task.get_future();
if(local_work_queue) {
local_work_queue->push(std::move(task));
} else {
pool_work_queue.push(std::move(task));
}
return res;
}
void run_pending_task() {
task_type task;
// Priority: first this thread's local_work_queue,
// then the global pool_work_queue (both take the newest task from the head),
// and only then steal the oldest task from the tail of another thread's queue.
if(pop_task_from_local_queue(task) ||
pop_task_from_pool_queue(task) ||
pop_task_from_other_thread_queue(task)) {
task.call();
}else {
std::this_thread::yield();
}
}
};
thread_local work_stealing_queue* thread_pool::local_work_queue {nullptr};
thread_local unsigned thread_pool::my_index = 0;
unsigned thread_pool::thread_count = 0;
#endif
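Before the full quicksort example, a minimal usage sketch of the pool (my addition): submit returns a std::future for the task's result, and a task submitted from a thread outside the pool (where local_work_queue is null) lands on the global pool_work_queue.

#include "thread_pool.hpp"
#include <iostream>

int main() {
    thread_pool pool;
    auto f = pool.submit([] { return 6 * 7; });  // main thread has no local queue
    std::cout << f.get() << std::endl;           // prints 42
    return 0;
}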
main.cpp
#include "thread_pool.hpp"
#include <iostream>
#include <algorithm>
#include <numeric>
#include <list>
#include <chrono>
template <typename T>
struct sorter {
thread_pool pool;
std::list<T> do_sort(std::list<T>& chunk_data) {
if(chunk_data.empty()) {
return chunk_data;
}
std::list<T> result;
result.splice(result.begin(), chunk_data, chunk_data.begin());
T const& partition_val = *result.begin();
typename std::list<T>::iterator divide_point = std::partition(
chunk_data.begin(), chunk_data.end(),
[&](T const& val) {
return val < partition_val;
}
);
std::list<T> new_lower_chunk;
new_lower_chunk.splice(new_lower_chunk.end(),
chunk_data, chunk_data.begin(), divide_point);
std::future<std::list<T>> new_lower = pool.submit([this,&new_lower_chunk]() {
return do_sort(new_lower_chunk);
});
std::list<T> new_higher(do_sort(chunk_data));
result.splice(result.end(), new_higher);
while(true) {
std::future_status status = new_lower.wait_for(std::chrono::seconds(0));
if(status == std::future_status::ready) {
break;
}
pool.run_pending_task();
}
result.splice(result.begin(), new_lower.get());
return result;
}
};
template <typename T>
std::list<T> parallel_quick_sort(std::list<T> input) {
if(input.empty()) {
return input;
}
sorter<T> s;
return s.do_sort(input);
}
int main(int argc, char* argv[]) {
std::list<int> ls;
for(std::size_t i=0; i<10000; ++i) {
if(i < 5000) {
ls.push_back(i + 5000);
} else {
ls.push_back(i-5000);
}
}
std::for_each(ls.begin(), ls.end(), [](int const& ele){
std::cout << ele << " ";
});
std::cout << std::endl;
std::list<int> res = parallel_quick_sort(ls);
std::cout << "After parallel_quick_sort, result: " << std::endl;
std::for_each(res.begin(), res.end(), [](int const& ele){
std::cout << ele << " ";
});
std::cout << std::endl;
return EXIT_SUCCESS;
}
The program output is shown below. Note that the list being sorted must not be too long; this has nothing to do with the thread_pool itself. The sorter functor recurses, and if the recursion goes too deep it can overflow the stack.
(screenshot of the program output)