美文网首页JavaJavaJava 程序员
java多线程——ThreadLocal那些不为人知的细节

java多线程——ThreadLocal那些不为人知的细节

作者: 马小莫QAQ | 来源:发表于2022-05-19 22:05 被阅读0次

    今天我们来剖析一下ThreadLocal的源码。

    说到ThreadLocal,我们在日常的开发工作中用的还是挺多的。

    比如,用户登录的时候我们可以通过ThreadLocal把用户的信息保存起来,而不用在每次使用的时候再去查一遍。

    Spring中的声明式事务也是通过ThreadLocal来保存数据库的链接,从而使多条SQL语句使用的是同一个数据库链接,保证事务。

    好了,话不多说,我们开始。

    引言

    首先看下ThreadLocal的整体结构

    在Thread类中保存了一个ThreadLocalMap的变量

    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;
    

    ThreadLocalMapThreadLocal的内部类,底层数据结构是一个数组

    /**
     * The table, resized as necessary.
     * table.length MUST always be a power of two.
     */
    private Entry[] table;
    

    元素是Entry类,这个类又是ThreadLocalMap的内部类,继承了WeakReference

    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;
    
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
    

    可以看到,Entry中k是弱引用,也就是ThreadLocal,而value仍然是强引用,我们通常所说的内存泄漏原因也就在这个地方,后面再说。

    好了,ThreadLocal的整体结构我们介绍完了,下面我们开始看他的核心方法。

    set

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    

    首先是set方法,我们在使用ThreadLocal的时候,肯定是先存然后再取,所以我们先看看他是怎么存的。

    首先调用getMap方法获取当前线程所保存的ThreadLocalMap

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    

    这个刚刚说过,Thread类里保存了一个ThreadLocalMap的变量

    如果为空先进行初始化

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    

    调用ThreadLocalMap的构造方法

    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY];
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
    }
    

    这里我们就可以看出,ThreadLocalMap的底层数据结构是数组

    首先构造了一个默认长度16Entry数组,然后计算数组下标

    private final int threadLocalHashCode = nextHashCode();
    
    private static AtomicInteger nextHashCode = new AtomicInteger();
    
    private static final int HASH_INCREMENT = 0x61c88647;
    
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
    

    ThreadLocal的hash值是一个叫threadLocalHashCode的变量,调用的是nextHashCode方法,这个方法又是调用一个AtomicInteger静态实例的getAndAdd方法。

    注意,这个nextHashCode变量是静态的,也就是说,每次新建一个ThreadLocal实例,他的hashcode都是在之前的基础上再加HASH_INCREMENT的。

    下面来看看HASH_INCREMENT这个变量,值是0x61c88647,转换成10进制就是1640531527

    看到这里,想必小伙伴们很自然的很有一个疑问,为什么每次hashcode都是在之前基础上再加一个这个值呢?

    我们先来看一个小实验

    public static void main(String[] args) {
        int a = 0x61c88647;
        int len = 16;
        for (int i = 1; i < len + 1; i++) {
            System.out.println(i + " " + ((a*i) & (len-1)));
        }
    }
    

    这段程序是模拟连续创建16个ThreadLocal实例,他的下标分布情况,我们看看结果如何

    居然没有一个下标重复的,再试下长度为32看看

    一样,没有一个下标重复,是不是很神奇

    这里面其实是蕴含了一些数学原理的,我们先看下这个数字是怎么来的

    把上面的公式变形一下,(long)((1<<31) * (Math.sqrt(5)-1)/2 * 2);

    (Math.sqrt(5)-1)/2这个值是什么?

    数字比较好的小伙伴可能立马就想到了,这不就是我们在初中学习的黄金分割吗,0.618

    所以,为什么每次hashcode递增1640531527,求出来的下标会均匀分布,原因就在这里,感兴趣的小伙伴可以去研究一下。

    我们继续往下看,在初始化完成之后会调用setThreshold方法设置扩容阈值

    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }
    

    这里的阈值和HashMap不太一样,HashMap是设置的3/4倍,他这里是2/3

    在第一次初始化之后,第二次调用的时候就会调用ThreadLocalMapset方法

    private void set(ThreadLocal<?> key, Object value) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
    
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();
    
            if (k == key) {
                e.value = value;
                return;
            }
    
            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }
    
        tab[i] = new Entry(key, value);
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }
    

    HashMap一样,循环遍历数组,找出符合条件的key,nextIndex是获取数组的下一位

    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }
    

    因为数组是有界的,所以当遍历超过数组范围时会重新回到0下标位。

    循环中有2个判断,第一个判断key是否相等,如果相等直接覆盖value值。

    第二个判断k是否为空,如果为空,替换当前数组位的值。

    这里注意了,当前索引位Entrykey为null,但是value是不为null的,这里说下前面提到的内存泄漏问题。

    在Java中引用分为4种,强软弱虚四大法王,强引用就是我们日常工作时用到的引用,比如User a = new User(),a就是强引用,软引用使用SoftReference包装,弱引用使用WeakReference包装,虚引用使用PhantomReference包装。

    虚引用一般是用来链接堆外对象,通过虚引用实现对堆外内存的回收弱引用每次发生GC的时候会被回收,而软引用只有在内存不足的时候才会被回收。

    ThreadLocalMap中Entry的key就是通过弱引用修饰的,所以每次发生GC时会被回收掉,导致key变成null,而value强引用,不会被回收,但是此时的value已经没有了任何意义,只是白白占着内存,所以也就导致了这部分内存不能被正常使用,造成内存泄漏

    好了,我们继续往下看。

    其实,从这里就可以看出,ThreadLocal处理hash碰撞是使用的线性探测法,就是如果计算出的索引位被别人占用了,那么就看下一位有没有被占用,一直找到没被占用的或者key为null的。

    看下他的替换方法replaceStaleEntry

    private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                           int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        Entry e;
    
        int slotToExpunge = staleSlot;
        for (int i = prevIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = prevIndex(i, len))
            if (e.get() == null)
                slotToExpunge = i;
    
        for (int i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
    
            if (k == key) {
                e.value = value;
    
                tab[i] = tab[staleSlot];
                tab[staleSlot] = e;
    
                // Start expunge at preceding stale entry if it exists
                if (slotToExpunge == staleSlot)
                    slotToExpunge = i;
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                return;
            }
    
            if (k == null && slotToExpunge == staleSlot)
                slotToExpunge = i;
        }
    
        // If key not found, put new entry in stale slot
        tab[staleSlot].value = null;
        tab[staleSlot] = new Entry(key, value);
    
        // If there are any other stale entries in run, expunge them
        if (slotToExpunge != staleSlot)
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }
    

    这个方法有点长,我们一点点细看。

    先看第一个for循环,我们现在知道当前数组下标位的key是为null的,待会是要被回收的,那么,我顺带看看前面还有没有Entry的key是null的,如果有的话那我就一并回收了岂不是更好。所以,这个for的作用就是向前遍历,如果找到key==null的,记录下位置,赋值给slotToExpunge。当遇到Entry为null时停下来,否则一直向前遍历,遍历到第一个元素时,会跳到数组的末尾继续往前遍历。

    这里可能有小伙伴会想了,如果我一直没遇到Entry为null的,会不会又遍历回自己了?

    显然,是不会的。

    忘记了吗,当数组的元素个数达到一定的值时是会扩容的,所以,数组中始终会有一些下标位是为null的。

    再看第二个for循环,这次是向后开始遍历,如果找到满足条件的key,那么就覆盖value,将当前索引位元素和staleSlot索引位元素替换下,画个图理解一下

    因为staleSlot索引位key为null,待会要被清理掉,所以把他和覆盖完value值的i位替换下。然后判断之前向前遍历的时候有没有找到key为null的,如果没找到,就将开始清理的位置设置为i,否则从之前找到的索引位开始清理。

    ThreadLocal清理的方法有2个,先看里面那个,expungeStaleEntry方法

    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
    
        // expunge entry at staleSlot
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;
    
        // Rehash until we encounter null
        Entry e;
        int i;
        for (i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                int h = k.threadLocalHashCode & (len - 1);
                if (h != i) {
                    tab[i] = null;
    
                    // Unlike Knuth 6.4 Algorithm R, we must scan until
                    // null because multiple entries could have been stale.
                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    tab[h] = e;
                }
            }
        }
        return i;
    }
    

    因为清理是从staleSlot开始,所以上来就把staleSlot位的元素清空了。

    然后向后遍历,遇到key为null的直接清空掉。

    如果不为null,就计算下标位,如果发现计算出来的下标位不是自己现在的位置,那么就说明当初set的时候,计算出来的索引位被占用了,被迫向后遍历了。

    那么,把当前i位设置为null

    为什么设置为null呢?

    因为此时已经清理掉了一些key为null的元素,当初占用他位置的元素此时很有可能被清理掉了,所以他要去夺回属于自己的东西,邪笑。

    紧接着,他就开始循环遍历,从计算出的h位开始寻找,一直找到空余的位置为止。

    最后返回i,注意了,这个i位元素是为null的。

    回到外层,再次进行一次清理,调用cleanSomeSlots方法

    private boolean cleanSomeSlots(int i, int n) {
        boolean removed = false;
        Entry[] tab = table;
        int len = tab.length;
        do {
            i = nextIndex(i, len);
            Entry e = tab[i];
            if (e != null && e.get() == null) {
                n = len;
                removed = true;
                i = expungeStaleEntry(i);
            }
        } while ( (n >>>= 1) != 0);
        return removed;
    }
    

    这里注意一下,因为当前i位元素是为null的,所以开始遍历的时候是从下一位开始遍历的,如果发现key为null的,再次调用之前的expungeStaleEntry方法开始清理。

    如果没找到key为null的,那么会循环log2^n次,找到了重新赋值n = len,再次循环log2^n次。

    再次回到外层方法,如果key不符合条件,那么判断key是否为null,为null再判断之前向前遍历的时候有没有发现key为null的Entry,没发现就设置开始清理的位置。

    一直遍历到元素为null,如果都没有找到符合条件的就跳出循环。新增一个Entry插入到staleSlot位置。因为之前循环的时候没有找到符合条件的key,没有进行清理工作,所以此时会进行清理工作。和之前循环调用的方法cleanSomeSlots(expungeStaleEntry(slotToExpunge), len)一样。

    回到一开始的set方法,如果循环中没有找到符合条件的key,也没有找到key为null的,那么就会构造一个Entry元素赋值到i位置上。

    一般新增一个元素后都会判断是否需要扩容,所以此时同样会判断扩容,但是扩容之前会进行一次随机清理,如果正巧清理了key为null的元素,那么因为清理了元素,所以数组个数减少了,也就不用再判断扩容了,如果没有清理到,此时判断是否超过阈值,超过了进行扩容。调用rehash方法

    private void rehash() {
        expungeStaleEntries();
    
        // Use lower threshold for doubling to avoid hysteresis
        if (size >= threshold - threshold / 4)
            resize();
    }
    

    在进行真正的扩容之前会把数组全部遍历一遍,清理key为null的元素,expungeStaleEntries这个方法

    private void expungeStaleEntries() {
        Entry[] tab = table;
        int len = tab.length;
        for (int j = 0; j < len; j++) {
            Entry e = tab[j];
            if (e != null && e.get() == null)
                expungeStaleEntry(j);
        }
    }
    

    可以看到,这里把数组从头到尾遍历了一遍,发现key为null的就调用expungeStaleEntry进行清理。

    清理之后判断是否超过阈值,这里把阈值减小了,减到原来的3/4

    这里作者可能考虑到,清理之后,如果元素数量还超过阈值的3/4,那么过不了多久肯定又会超过2/3,与其那个时候再扩容不如现在提前扩容算了。

    调用resize进行扩容

    private void resize() {
        Entry[] oldTab = table;
        int oldLen = oldTab.length;
        int newLen = oldLen * 2;
        Entry[] newTab = new Entry[newLen];
        int count = 0;
    
        for (int j = 0; j < oldLen; ++j) {
            Entry e = oldTab[j];
            if (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null; // Help the GC
                } else {
                    int h = k.threadLocalHashCode & (newLen - 1);
                    while (newTab[h] != null)
                        h = nextIndex(h, newLen);
                    newTab[h] = e;
                    count++;
                }
            }
        }
    
        setThreshold(newLen);
        size = count;
        table = newTab;
    }
    

    这个方法比较简单,就是数组容量扩大一倍,然后把老数组的元素转移到新数组上,

    到这里set方法我们就剖析完了,下面我们看get方法

    get

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    

    get方法相对而言简单一些,首先获取当前线程的ThreadLocalMap变量,如果为null,调用setInitialValue初始化

    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
    

    initialValue方法返回的是个null,然后调用前面说的createMap方法进行初始化

    如果ThreadLocalMap变量不为null,调用getEntry获取Entry元素。

    private Entry getEntry(ThreadLocal<?> key) {
        int i = key.threadLocalHashCode & (table.length - 1);
        Entry e = table[i];
        if (e != null && e.get() == key)
            return e;
        else
            return getEntryAfterMiss(key, i, e);
    }
    

    这里如果计算出来的i索引位满足就返回,否则调用getEntryAfterMiss方法

    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;
    
        while (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == key)
                return e;
            if (k == null)
                expungeStaleEntry(i);
            else
                i = nextIndex(i, len);
            e = tab[i];
        }
        return null;
    }
    

    如果e为null说明已经发生GC被回收掉了,返回null。否则,从i开始往后遍历,满足条件就返回,为null就清理,一直到e==null或者找到符合条件的为止。

    最后回到get方法判断有没有找到符合条件的Entry,找到就返回,没找到继续调用setInitialValue方法,将当前ThreadLocal实例作为key,null作为value,构造一个Entry插入到数组中。

    最后看下remove方法

    reove

    public void remove() {
       ThreadLocalMap m = getMap(Thread.currentThread());
       if (m != null)
           m.remove(this);
    }
    

    调用ThreadLocalMapremove方法

    private void remove(ThreadLocal<?> key) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            if (e.get() == key) {
                e.clear();
                expungeStaleEntry(i);
                return;
            }
        }
    }
    

    计算出当前ThreadLocal实例所在的i索引位,判断此位置的key是否是自己,是,就删除,然后调用expungeStaleEntry方法看看能不能清理掉一些元素,然后返回。

    实战应用

    在我们日常的工作中,线上出现问题的话需要去排查,而现在微服务盛行,许多项目都由传统的单体式拆分成了微服务,经常客户端一个请求过来会经过好几个系统,这个时候为了追踪整个链路的调用情况,我们通常会创建一个traceId,贯穿整个调用链路,这样,我们在查日志的时候就可以通过这个traceId将整个调用过程串联起来。

    但是为了提高系统的快速响应能力,我们经常会创建线程池来进行异步执行,这个时候traceId就会断掉,如果恰巧是线程池执行出现了错误,那么就无法跟踪到了。

    这个时候ThreadLocal就派上用场了。

    在日志框架slf4j里有一个叫MDC的类,通过他就可以实现我们需要的功能。

    我们先看一下正常情况下调用的过程。

    /**
     * @author 程序员阿轩
     */
    @Component
    public class WebFilter extends GenericFilterBean {
        Logger logger = LoggerFactory.getLogger(WebFilter.class);
    
        @Override
        public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) {
            System.out.println("WebFilter doFilter-----------");
            try {
                HttpServletRequest request = (HttpServletRequest) servletRequest;
                String traceId = request.getHeader(TraceConstants.X_COMMON_TRACE_ID);
                if (StrUtil.isBlank(traceId)) {
                    traceId = TraceUtils.newTraceId();
                }
                TraceContext.setTraceId(traceId);
                System.out.println("WebFilter traceId ->" + traceId);
                filterChain.doFilter(servletRequest, servletResponse);
            } catch (Throwable e) {
    
            } finally {
                TraceContext.clear();
            }
        }
    }
    

    首先请求来到过滤器,我们在这里给他设置一个traceId

    /**
     * @author 程序员阿轩
     */
    public class TraceContext {
        private TraceContext() {
        }
    
        public static String getTraceId() {
            return MDC.get(TraceConstants.X_COMMON_TRACE_ID);
        }
    
        public static void setTraceId(String traceId) {
            MDC.put(TraceConstants.X_COMMON_TRACE_ID, traceId);
        }
    
        public static Map<String, String> getContextMap() {
            return MDC.getCopyOfContextMap();
        }
    
        public static void setContextMap(Map<String, String> contextMap) {
            if (contextMap == null) {
                contextMap = new HashMap<>();
            }
            MDC.setContextMap(contextMap);
        }
    
        public static void clear() {
            MDC.remove(TraceConstants.X_COMMON_TRACE_ID);
        }
    
        public static void clearAll() {
            MDC.clear();
        }
    }
    

    接着请求来到controller

    /**
     * @author 程序员阿轩
     */
    @RestController
    public class TraceController {
        @Autowired
        private TestService testService;
    
        @GetMapping("/trace")
        public String trace() {
            System.out.println("main->" + Thread.currentThread().getName());
            testService.test();
            return "程序员阿轩";
        }
    }
    

    service类

    /**
     * @author 程序员阿轩
     */
    @Service
    public class TestService {
        @Async("ecsAsyncExecutor")
        public void test() {
            System.out.println("线程池中线程->" + Thread.currentThread().getName() + "---" + MDC.get(TraceConstants.X_COMMON_TRACE_ID));
            try {
                Thread.sleep(500000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
    

    线程池配置类

    /**
     * @author 程序员阿轩
     */
    @Configuration
    public class AsyncExecutorConfig implements AsyncConfigurer {
        private static final Logger LOGGER = LoggerFactory.getLogger(AsyncExecutorConfig.class);
    
        private final TaskExecutionProperties properties;
    
        public AsyncExecutorConfig(TaskExecutionProperties properties) {
            this.properties = properties;
        }
    
        @Override
        @Bean("ecsAsyncExecutor")
        public ThreadPoolTaskExecutor getAsyncExecutor() {
            ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor() {
                @Override
                public <T> Future<T> submit(Callable<T> task) {
                    return super.submit(task);
                }
    
                @Override
                public void execute(Runnable task) {
                    super.execute(task);
                }
            };
            executor.setCorePoolSize(properties.getPool().getCoreSize());
            executor.setMaxPoolSize(properties.getPool().getMaxSize());
            executor.setQueueCapacity(properties.getPool().getQueueCapacity());
            executor.setAllowCoreThreadTimeOut(properties.getPool().isAllowCoreThreadTimeout());
            executor.setKeepAliveSeconds((int) properties.getPool().getKeepAlive().getSeconds());
            executor.setThreadNamePrefix(properties.getThreadNamePrefix());
            executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
            executor.initialize();
            return executor;
        }
    }
    

    这里主要为了演示,一些异常异常捕捉什么的就省掉了。

    yaml配置

    spring:
      task:
        execution:
          pool:
            allow-core-thread-timeout: true
            core-size: 1
            max-size: 5
            queue-capacity: 3
            keep-alive: 60s
          thread-name-prefix: a-xuan
    

    运行程序,打印结果

    WebFilter doFilter-----------
    WebFilter traceId ->08dd98a2-7343-4695-b509-2103bda6f7ef
    main->http-nio-9050-exec-1
    线程池中线程->a-xuan1---null
    

    可以看到,线程池中的线程获取traceId为null,获取不到。

    我们稍微改造下线程池的配置类

    /**
     * @author 程序员阿轩
     */
    @Configuration
    public class AsyncExecutorConfig implements AsyncConfigurer {
        private static final Logger LOGGER = LoggerFactory.getLogger(AsyncExecutorConfig.class);
    
        private final TaskExecutionProperties properties;
    
        public AsyncExecutorConfig(TaskExecutionProperties properties) {
            this.properties = properties;
        }
    
        @Override
        @Bean("ecsAsyncExecutor")
        public ThreadPoolTaskExecutor getAsyncExecutor() {
            ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor() {
                @Override
                public <T> Future<T> submit(Callable<T> task) {
                    return super.submit(ThreadMdcUtil.wrap(task, MDC.getCopyOfContextMap()));
                }
    
                @Override
                public void execute(Runnable task) {
                    super.execute(ThreadMdcUtil.wrap(task, MDC.getCopyOfContextMap()));
                }
            };
            executor.setCorePoolSize(properties.getPool().getCoreSize());
            executor.setMaxPoolSize(properties.getPool().getMaxSize());
            executor.setQueueCapacity(properties.getPool().getQueueCapacity());
            executor.setAllowCoreThreadTimeOut(properties.getPool().isAllowCoreThreadTimeout());
            executor.setKeepAliveSeconds((int) properties.getPool().getKeepAlive().getSeconds());
            executor.setThreadNamePrefix(properties.getThreadNamePrefix());
            executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
            executor.initialize();
            return executor;
        }
    }
    

    我们把需要执行的任务包装一层

    /**
     * @author 程序员阿轩
     */
    public class ThreadMdcUtil {
        public static <T> Callable<T> wrap(final Callable<T> callable, final Map<String, String> context) {
            return new Callable<T>() {
                @Override
                public T call() throws Exception {
                    if (context == null) {
                        MDC.clear();
                    } else {
                        MDC.setContextMap(context);
                    }
                    System.out.println("wrap: " + Thread.currentThread().getName() + "---" + MDC.get(TraceConstants.X_COMMON_TRACE_ID));
    
                    try {
                        return callable.call();
                    } finally {
                        MDC.clear();
                    }
                }
            };
    
        public static Runnable wrap(final Runnable runnable, final Map<String, String> context) {
            return () -> {
                if (context == null) {
                    MDC.clear();
                } else {
                    MDC.setContextMap(context);
                }
    //            System.out.println("wrap: " + Thread.currentThread().getName() + "---" + MDC.get(TraceConstants.X_COMMON_TRACE_ID));
    
                try {
                    runnable.run();
                } finally {
                    MDC.clear();
                }
            };
        }
    }
    

    再次执行看下打印结果

    WebFilter doFilter-----------
    WebFilter traceId ->655795e7-a58e-408b-9758-e65c04aa4e4a
    main->http-nio-9050-exec-1
    submit->http-nio-9050-exec-1
    wrap: a-xuan1---655795e7-a58e-408b-9758-e65c04aa4e4a
    线程池中线程->a-xuan1---655795e7-a58e-408b-9758-e65c04aa4e4a
    

    可以看到此时线程池中的线程拿到了traceId,从而完成了链路追踪的功能。

    下面我们简单看下MDC是怎么实现这个功能的。

    我们看下刚刚使用到的put和get方法

    public static void put(String key, String val) throws IllegalArgumentException {
        if (key == null) {
            throw new IllegalArgumentException("key parameter cannot be null");
        } else if (mdcAdapter == null) {
            throw new IllegalStateException("MDCAdapter cannot be null. See also http://www.slf4j.org/codes.html#null_MDCA");
        } else {
            mdcAdapter.put(key, val);
        }
    }
    
    public static String get(String key) throws IllegalArgumentException {
        if (key == null) {
            throw new IllegalArgumentException("key parameter cannot be null");
        } else if (mdcAdapter == null) {
            throw new IllegalStateException("MDCAdapter cannot be null. See also http://www.slf4j.org/codes.html#null_MDCA");
        } else {
            return mdcAdapter.get(key);
        }
    }
    

    可以看到,MDC只是个门面,真正发挥作用的是MDCAdapter这个东西。

    public interface MDCAdapter {
        void put(String var1, String var2);
    
        String get(String var1);
    
        void remove(String var1);
    
        void clear();
    
        Map<String, String> getCopyOfContextMap();
    
        void setContextMap(Map<String, String> var1);
    }
    

    MDCAdapter实际上是一个接口,功能由他的子类来实现

    现在我们日志框架通常使用的都是LogBack,我们看下LogBack的实现

    public void put(String key, String val) throws IllegalArgumentException {
        if (key == null) {
            throw new IllegalArgumentException("key cannot be null");
        } else {
            Map<String, String> oldMap = (Map)this.copyOnThreadLocal.get();
            Integer lastOp = this.getAndSetLastOperation(1);
            if (!this.wasLastOpReadOrNull(lastOp) && oldMap != null) {
                oldMap.put(key, val);
            } else {
                Map<String, String> newMap = this.duplicateAndInsertNewMap(oldMap);
                newMap.put(key, val);
            }
    
        }
    }
    

    可以看到,核心是一个Map的变量copyOnThreadLocal,从名字其实已经能够看出来了

    final ThreadLocal<Map<String, String>> copyOnThreadLocal = new ThreadLocal();
    private static final int WRITE_OPERATION = 1;
    private static final int MAP_COPY_OPERATION = 2;
    final ThreadLocal<Integer> lastOperation = new ThreadLocal();
    

    没错,他就是一个ThreadLocal,所有的一切都是围绕着这个ThreadLocal来进行的。

    总结

    本篇文章从ThreadLocal的源码剖析说到他在实际工作中的使用,其实小伙伴们可以发现,很多技术的底层都是我们熟悉的东西,只不过经过了层层包装,穿上了各种各样华丽的马甲之后,我们不认识他了,但是当你一步步去深究,像洋葱一样一层一层剥开他的时候,最后,你会情不自禁的感叹一句,哦---,原来如此,soga。

    作者:枫林晚
    链接:https://juejin.cn/post/7098951677121658917
    来源:稀土掘金

    相关文章

      网友评论

        本文标题:java多线程——ThreadLocal那些不为人知的细节

        本文链接:https://www.haomeiwen.com/subject/sslxprtx.html