美文网首页
boost::thread_specific_ptr对象的原理及

boost::thread_specific_ptr对象的原理及

作者: FredricZhu | 来源:发表于2021-04-07 05:20 被阅读0次

    代码部分不多说,了解Java ThreadLocal的原理的一般都懂,我们这里重点关注下在C++ boost中是如何实现ThreadLocal的原理。
    先上代码,
    CMakeLists.txt

    cmake_minimum_required(VERSION 2.6)
    project(lexical_cast)
    
    add_definitions(-std=c++14)
    
    include_directories("/usr/local/include")
    link_directories("/usr/local/lib")
    file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
    foreach( sourcefile ${APP_SOURCES} )
        file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${sourcefile})
        string(REPLACE ".cpp" "" file ${filename})
        add_executable(${file} ${sourcefile})
        target_link_libraries(${file} boost_filesystem boost_thread boost_system boost_serialization pthread boost_chrono)
    endforeach( sourcefile ${APP_SOURCES} )
    

    main.cpp

    #include <boost/noncopyable.hpp>
    #include <boost/thread/tss.hpp>
    #include <boost/thread/thread.hpp>
    
    #include <iostream>
    
    #include <cassert>
    
    class connection: boost::noncopyable {
    public:
        int open_count_;
        connection(): open_count_(0) {}
    
        void open();
        void send_result(int result);
    };
    
    // 打开连接,简单的设置 open_count_ = 1
    void connection::open() {
        assert(!open_count_);
        open_count_ = 1;
    }
    
    void connection::send_result(int /*result*/) {}
    
    connection& get_connection();
    
    // 类似Java中的 thread_local变量,一个线程一份
    boost::thread_specific_ptr<connection> connection_ptr;
    
    connection& get_connection() {
        connection* p = connection_ptr.get();
        // 还未初始化
        if(!p) {
            // 初始化连接对象
            std::cerr << "开始初始化连接, 线程ID: " << boost::this_thread::get_id() << std::endl;
            connection_ptr.reset(new connection());
            p = connection_ptr.get();
            p->open();
        }
        return *p;
    }
    
    void task() {
        int result = 2;
        get_connection().send_result(result);
    }
    
    void run_tasks() {
        for(std::size_t i=0; i<10000000; ++i) {
            task();
        }
    }
    
    int main(int argc, char* argv[]) {
    
        boost::thread t1;
        boost::thread_group threads;
        for(std::size_t i=0; i<4; ++i) {
            threads.create_thread(&run_tasks);
        }
        threads.join_all();
        return 0;
    }
    

    程序输出如下,一共会打印4个线程ID,
    说明四个线程会生成四个thread_specific_ptr对象。


    图片.png

    我们再来看这个thread_specific_ptr的原理,
    照例我们先来画下类图,


    图片.png

    然后先看下thread_specific_ptr类的实现,
    boost::thread_specific_ptr类实现

    
        template <typename T>
        class thread_specific_ptr
        {
        private:
            // 拷贝构造和赋值操作符私有化,不允许拷贝构造和赋值 
            thread_specific_ptr(thread_specific_ptr&);
            thread_specific_ptr& operator=(thread_specific_ptr&);
            // 定义一个清理函数类型的函数指针类型
           //  其类型为 void (*) (T*)
            typedef void(*original_cleanup_func_t)(T*);
          
            // 定义一个范化的清理函数
            static void default_deleter(T* data)
            {
                delete data;
            }
            // 定义一个清理函数调用器
            // 这个调用器函数的功能相当简单,就是调用清理函数
            static void cleanup_caller(detail::thread::cleanup_func_t cleanup_function,void* data)
            {
                reinterpret_cast<original_cleanup_func_t>(cleanup_function)(static_cast<T*>(data));
            }
    
             // 清理函数成员变量
            detail::thread::cleanup_func_t cleanup;
    
        public:
            typedef T element_type;
             
            // 默认构造器,使用默认的清理函数
            thread_specific_ptr():
                cleanup(reinterpret_cast<detail::thread::cleanup_func_t>(&default_deleter))
            {}
    
            // 也可以自定义清理函数进行构造,但是必须符合 void (*)(T*)
            // 类型的描述符 
            explicit thread_specific_ptr(void (*func_)(T*))
              : cleanup(reinterpret_cast<detail::thread::cleanup_func_t>(func_))
            {}
           // 析构函数,将当前线程绑定的 thread_data_base中的
           //  tss_data字段置为空
            ~thread_specific_ptr()
            {
                detail::set_tss_data(this,0,0,0,true);
            }
            // 获取当前线程对应的tss_data数据,
           // 可以理解为thread_specific_data
           // 将其转换为 T类型指针,返回
            T* get() const
            {
                return static_cast<T*>(detail::get_tss_data(this));
            }
            // 指针操作符重载
            T* operator->() const
            {
                return get();
            }
            // 解引用操作符重载
            typename add_reference<T>::type operator*() const
            {
                return *get();
            }
            T* release()
            {
                T* const temp=get();
                detail::set_tss_data(this,0,0,0,false);
                return temp;
            }
            // reset 
            // 设置当前线程的thread_data_base数据中的tss_data字段为新值
            void reset(T* new_value=0)
            {
                T* const current_value=get();
                if(current_value!=new_value)
                {
                    detail::set_tss_data(this,&cleanup_caller,cleanup,new_value,true);
                }
            }
        };
    }
    

    所以说看到这里最重要的其实就是set_tss_data和get_tss_data两个全局函数的实现了。
    但是这两个全局函数是线程库里面的,编译到了so库文件中,直接看是看不到的。
    可以转到 boost源码的下载目录,例如,/home/fredric/software/boost_1_71_0/libs/thread/,然后用vscode打开当前目录,做一个全局搜索。
    set_tss_data的实现,

     void set_tss_data(void const* key,
                              detail::tss_data_node::cleanup_caller_t caller,
                              detail::tss_data_node::cleanup_func_t func,
                              void* tss_data,bool cleanup_existing)
            {
                if(tss_data_node* const current_node=find_tss_data(key))
                {
                    if(cleanup_existing && current_node->func && (current_node->value!=0))
                    {
                        (*current_node->caller)(current_node->func,current_node->value);
                    }
                    if(func || (tss_data!=0))
                    {
                        current_node->caller=caller;
                        current_node->func=func;
                        current_node->value=tss_data;
                    }
                    else
                    {
                        erase_tss_node(key);
                    }
                }
                else if(func || (tss_data!=0))
                {
                    add_new_tss_node(key,caller,func,tss_data);
                }
            }
        }
    
          void add_new_tss_node(void const* key,
                                  detail::tss_data_node::cleanup_caller_t caller,
                                  detail::tss_data_node::cleanup_func_t func,
                                  void* tss_data)
            {
                detail::thread_data_base* const current_thread_data(get_or_make_current_thread_data());
                current_thread_data->tss_data.insert(std::make_pair(key,tss_data_node(caller,func,tss_data)));
            }
    
      struct BOOST_THREAD_DECL thread_data_base:
                enable_shared_from_this<thread_data_base>
            {
                thread_data_ptr self;
                pthread_t thread_handle;
                boost::mutex data_mutex;
                boost::condition_variable done_condition;
                bool done;
                bool join_started;
                bool joined;
                boost::detail::thread_exit_callback_node* thread_exit_callbacks;
                std::map<void const*,boost::detail::tss_data_node> tss_data;
      ......
         };
         struct tss_data_node
            {
                typedef void(*cleanup_func_t)(void*);
                typedef void(*cleanup_caller_t)(cleanup_func_t, void*);
    
                cleanup_caller_t caller;
                cleanup_func_t func;
                void* value;
    
                tss_data_node(cleanup_caller_t caller_,cleanup_func_t func_,void* value_):
                    caller(caller_),func(func_),value(value_)
                {}
            };
    
          
         // get_tss_data方法的实现
          void* get_tss_data(void const* key)
            {
                if(tss_data_node* const current_node=find_tss_data(key))
                {
                    return current_node->value;
                }
                return 0;
            }
       
         tss_data_node* find_tss_data(void const* key)
            {
                detail::thread_data_base* const current_thread_data(get_current_thread_data());
                if(current_thread_data)
                {
                    std::map<void const*,tss_data_node>::iterator current_node=
                        current_thread_data->tss_data.find(key);
                    if(current_node!=current_thread_data->tss_data.end())
                    {
                        return &current_node->second;
                    }
                }
                return 0;
            }
    

    可以看到第一set_tss_data的时候,会调用add_new_tss_node方法,把数据通过key->value pair的形式插入到 当前线程的thread_data_base的tss_data字段中。
    调用get_tss_data方法其实就是根据当前的thread_specific_ptr对象来获取tss_data的值。
    整个代码逻辑非常清晰而且简单。

    相关文章

      网友评论

          本文标题:boost::thread_specific_ptr对象的原理及

          本文链接:https://www.haomeiwen.com/subject/gynskltx.html