代码部分不多说,了解Java ThreadLocal的原理的一般都懂,我们这里重点关注下在C++ boost中是如何实现ThreadLocal的原理。
先上代码,
CMakeLists.txt
cmake_minimum_required(VERSION 2.6)
project(lexical_cast)
add_definitions(-std=c++14)
include_directories("/usr/local/include")
link_directories("/usr/local/lib")
file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
foreach( sourcefile ${APP_SOURCES} )
file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${sourcefile})
string(REPLACE ".cpp" "" file ${filename})
add_executable(${file} ${sourcefile})
target_link_libraries(${file} boost_filesystem boost_thread boost_system boost_serialization pthread boost_chrono)
endforeach( sourcefile ${APP_SOURCES} )
main.cpp
#include <boost/noncopyable.hpp>
#include <boost/thread/tss.hpp>
#include <boost/thread/thread.hpp>
#include <iostream>
#include <cassert>
class connection: boost::noncopyable {
public:
int open_count_;
connection(): open_count_(0) {}
void open();
void send_result(int result);
};
// 打开连接,简单的设置 open_count_ = 1
void connection::open() {
assert(!open_count_);
open_count_ = 1;
}
void connection::send_result(int /*result*/) {}
connection& get_connection();
// 类似Java中的 thread_local变量,一个线程一份
boost::thread_specific_ptr<connection> connection_ptr;
connection& get_connection() {
connection* p = connection_ptr.get();
// 还未初始化
if(!p) {
// 初始化连接对象
std::cerr << "开始初始化连接, 线程ID: " << boost::this_thread::get_id() << std::endl;
connection_ptr.reset(new connection());
p = connection_ptr.get();
p->open();
}
return *p;
}
void task() {
int result = 2;
get_connection().send_result(result);
}
void run_tasks() {
for(std::size_t i=0; i<10000000; ++i) {
task();
}
}
int main(int argc, char* argv[]) {
boost::thread t1;
boost::thread_group threads;
for(std::size_t i=0; i<4; ++i) {
threads.create_thread(&run_tasks);
}
threads.join_all();
return 0;
}
程序输出如下,一共会打印4个线程ID,
说明四个线程会生成四个thread_specific_ptr对象。
图片.png
我们再来看这个thread_specific_ptr的原理,
照例我们先来画下类图,
图片.png
然后先看下thread_specific_ptr类的实现,
boost::thread_specific_ptr类实现
template <typename T>
class thread_specific_ptr
{
private:
// 拷贝构造和赋值操作符私有化,不允许拷贝构造和赋值
thread_specific_ptr(thread_specific_ptr&);
thread_specific_ptr& operator=(thread_specific_ptr&);
// 定义一个清理函数类型的函数指针类型
// 其类型为 void (*) (T*)
typedef void(*original_cleanup_func_t)(T*);
// 定义一个范化的清理函数
static void default_deleter(T* data)
{
delete data;
}
// 定义一个清理函数调用器
// 这个调用器函数的功能相当简单,就是调用清理函数
static void cleanup_caller(detail::thread::cleanup_func_t cleanup_function,void* data)
{
reinterpret_cast<original_cleanup_func_t>(cleanup_function)(static_cast<T*>(data));
}
// 清理函数成员变量
detail::thread::cleanup_func_t cleanup;
public:
typedef T element_type;
// 默认构造器,使用默认的清理函数
thread_specific_ptr():
cleanup(reinterpret_cast<detail::thread::cleanup_func_t>(&default_deleter))
{}
// 也可以自定义清理函数进行构造,但是必须符合 void (*)(T*)
// 类型的描述符
explicit thread_specific_ptr(void (*func_)(T*))
: cleanup(reinterpret_cast<detail::thread::cleanup_func_t>(func_))
{}
// 析构函数,将当前线程绑定的 thread_data_base中的
// tss_data字段置为空
~thread_specific_ptr()
{
detail::set_tss_data(this,0,0,0,true);
}
// 获取当前线程对应的tss_data数据,
// 可以理解为thread_specific_data
// 将其转换为 T类型指针,返回
T* get() const
{
return static_cast<T*>(detail::get_tss_data(this));
}
// 指针操作符重载
T* operator->() const
{
return get();
}
// 解引用操作符重载
typename add_reference<T>::type operator*() const
{
return *get();
}
T* release()
{
T* const temp=get();
detail::set_tss_data(this,0,0,0,false);
return temp;
}
// reset
// 设置当前线程的thread_data_base数据中的tss_data字段为新值
void reset(T* new_value=0)
{
T* const current_value=get();
if(current_value!=new_value)
{
detail::set_tss_data(this,&cleanup_caller,cleanup,new_value,true);
}
}
};
}
所以说看到这里最重要的其实就是set_tss_data和get_tss_data两个全局函数的实现了。
但是这两个全局函数是线程库里面的,编译到了so库文件中,直接看是看不到的。
可以转到 boost源码的下载目录,例如,/home/fredric/software/boost_1_71_0/libs/thread/,然后用vscode打开当前目录,做一个全局搜索。
set_tss_data的实现,
void set_tss_data(void const* key,
detail::tss_data_node::cleanup_caller_t caller,
detail::tss_data_node::cleanup_func_t func,
void* tss_data,bool cleanup_existing)
{
if(tss_data_node* const current_node=find_tss_data(key))
{
if(cleanup_existing && current_node->func && (current_node->value!=0))
{
(*current_node->caller)(current_node->func,current_node->value);
}
if(func || (tss_data!=0))
{
current_node->caller=caller;
current_node->func=func;
current_node->value=tss_data;
}
else
{
erase_tss_node(key);
}
}
else if(func || (tss_data!=0))
{
add_new_tss_node(key,caller,func,tss_data);
}
}
}
void add_new_tss_node(void const* key,
detail::tss_data_node::cleanup_caller_t caller,
detail::tss_data_node::cleanup_func_t func,
void* tss_data)
{
detail::thread_data_base* const current_thread_data(get_or_make_current_thread_data());
current_thread_data->tss_data.insert(std::make_pair(key,tss_data_node(caller,func,tss_data)));
}
struct BOOST_THREAD_DECL thread_data_base:
enable_shared_from_this<thread_data_base>
{
thread_data_ptr self;
pthread_t thread_handle;
boost::mutex data_mutex;
boost::condition_variable done_condition;
bool done;
bool join_started;
bool joined;
boost::detail::thread_exit_callback_node* thread_exit_callbacks;
std::map<void const*,boost::detail::tss_data_node> tss_data;
......
};
struct tss_data_node
{
typedef void(*cleanup_func_t)(void*);
typedef void(*cleanup_caller_t)(cleanup_func_t, void*);
cleanup_caller_t caller;
cleanup_func_t func;
void* value;
tss_data_node(cleanup_caller_t caller_,cleanup_func_t func_,void* value_):
caller(caller_),func(func_),value(value_)
{}
};
// get_tss_data方法的实现
void* get_tss_data(void const* key)
{
if(tss_data_node* const current_node=find_tss_data(key))
{
return current_node->value;
}
return 0;
}
tss_data_node* find_tss_data(void const* key)
{
detail::thread_data_base* const current_thread_data(get_current_thread_data());
if(current_thread_data)
{
std::map<void const*,tss_data_node>::iterator current_node=
current_thread_data->tss_data.find(key);
if(current_node!=current_thread_data->tss_data.end())
{
return ¤t_node->second;
}
}
return 0;
}
可以看到第一set_tss_data的时候,会调用add_new_tss_node方法,把数据通过key->value pair的形式插入到 当前线程的thread_data_base的tss_data字段中。
调用get_tss_data方法其实就是根据当前的thread_specific_ptr对象来获取tss_data的值。
整个代码逻辑非常清晰而且简单。
网友评论