美文网首页
ThreadLocal源码分析及避坑实践

ThreadLocal源码分析及避坑实践

作者: 一路花开_8fab | 来源:发表于2018-09-15 17:18 被阅读0次

    ThreadLocal可以为每个线程保存一份变量的副本,防止在多线程情况下,属于某个线程的变量被其他线程修改。下面从源码角度分析其实现原理。
    观察最常使用的get()和set()方法可以看出:

    • get()操作需要获取当前线程对应的ThreadLocalMap,再根据当前ThreadLoca变量的引用,获取当前线程的变量副本。
    • set(T value)操作同样需要获取当前线程对应的ThreadLocalMap,再根据当前ThreadLoca变量的引用,设置当前线程的变量副本。
      public T get() {
            Thread t = Thread.currentThread();
            ThreadLocalMap map = getMap(t);
            if (map != null) {
                ThreadLocalMap.Entry e = map.getEntry(this);
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    T result = (T)e.value;
                    return result;
                }
            }
            return setInitialValue();
        }
    
        public void set(T value) {
            Thread t = Thread.currentThread();
            ThreadLocalMap map = getMap(t);
            if (map != null)
                map.set(this, value);
            else
                createMap(t, value);
        }
    

    我们来看一下ThreadLocal、ThreadLocalMap和Thread的关系。ThreadLocalMap是ThreadLocal的静态内部类,而每一个Thread对象都包含一个ThreadLocalMap。

    public class ThreadLocal<T> {
        ......
       
        static class ThreadLocalMap {
            ......
        }
        ......
    }
    
    public class Thread implements Runnable {
    
        ......
        /* ThreadLocal values pertaining to this thread. This map is maintained
         * by the ThreadLocal class. */
        ThreadLocal.ThreadLocalMap threadLocals = null;
        ....
    }
    

    既然每个线程都维护一个ThreadLocalMap,那么为什么不设计Map<Thread,T>这种形式,一个线程对应一个存储对象,而“托管”给ThreadLocal来保存每个线程的变量副本呢?ThreadLocal这样设计的目的主要有两个:

    • 一是可以保证当前线程结束时相关对象能尽快被回收;
    • 二是ThreadLocalMap中的元素会大大减少,我们都知道map过大更容易造成哈希冲突而导致性能变差。

    下面我们着重看下ThreadLocalMap这个数据结构。

    static class ThreadLocalMap {
    
            static class Entry extends WeakReference<ThreadLocal<?>> {
                /** The value associated with this ThreadLocal. */
                Object value;
    
                Entry(ThreadLocal<?> k, Object v) {
                    super(k);
                    value = v;
                }
            }
        ......
    }
    

    ThreadLocalMap中的key是ThreadLocal<?>对象,value值当前线程的变量副本。这里需要注意的是,ThreadLoalMap的Entry是继承WeakReference,和HashMap很大的区别是,Entry中没有next字段,所以就不存在链表的情况了。那么ThreadLocalMap在set和get时是如何解决hash冲突的呢,接下来进行介绍。

    hash冲突

    private void set(ThreadLocal<?> key, Object value) {
    
                // We don't use a fast path as with get() because it is at
                // least as common to use set() to create new entries as
                // it is to replace existing ones, in which case, a fast
                // path would fail more often than not.
    
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
    
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    ThreadLocal<?> k = e.get();
    
                    if (k == key) {
                        e.value = value;
                        return;
                    }
    
                    if (k == null) {
                        replaceStaleEntry(key, value, i);
                        return;
                    }
                }
    
                tab[i] = new Entry(key, value);
                int sz = ++size;
                if (!cleanSomeSlots(i, sz) && sz >= threshold)
                    rehash();
            }
    

    在往ThreadLocalMap中put元素时,首先计算索引

    • 如果该索引出没有Entry,则退出循环,构造一个新的Entry插入;
    • 如果该索引处已插入Entry,并且对应的key正好为当前的ThreadLocal<?>对象,则直接进行value的替换,
    • 如果该索引处已插入Entry,并且对应的key不是当前的ThreadLocal<?>对象,则计算下一个索引。计算下一个索引的方式其实就是当前索引加1,若超过数组长度,则索引为0。
       /**
         * Increment i modulo len.
       */
       private static int nextIndex(int i, int len) {
           return ((i + 1 < len) ? i + 1 : 0);
       }
    

    在从ThreadLocalMap中get元素时,首先计算索引

    private Entry getEntry(ThreadLocal<?> key) {
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                if (e != null && e.get() == key)
                    return e;
                else
                    return getEntryAfterMiss(key, i, e);
            }
    
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                Entry[] tab = table;
                int len = tab.length;
    
                while (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == key)
                        return e;
                    if (k == null)
                        expungeStaleEntry(i);
                    else
                        i = nextIndex(i, len);
                    e = tab[i];
                }
                return null;
            }
    
    • 若索引处有Entry,并且对应的key正好为当前的ThreadLocal<?>对象,则返回对应的value
    • 若索引处没有Entry,则按照与set方法相似的过程,计算下一个索引,直到找到某个Entry,对应的key正好为当前的ThreadLocal<?>对象,如果找不到,最终返回null

    常见的坑

    由于ThreadLocal其内部条目为弱引用,当key为null时,该条目就变成“废弃条目”,相关“value”的回收,往往依赖于几个关键点,即set、remove、rehash。下面是set示例:

    private void set(ThreadLocal<?> key, Object value) {
    
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
    
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    ThreadLocal<?> k = e.get();
    
                    if (k == key) {
                        e.value = value;
                        return;
                    }
    
                    if (k == null) {
                        // 替换废弃条目
                        replaceStaleEntry(key, value, i);
                        return;
                    }
                }
    
                tab[i] = new Entry(key, value);
                int sz = ++size;
                
                // 扫描并清理发现的废弃条目,并检查容量是否超限
                if (!cleanSomeSlots(i, sz) && sz >= threshold)
                    // 清理废弃条目,如果仍然超限,则扩容
                    rehash();
            }
    

    具体的清理逻辑是在cleanSomeSlots和expungeStaleEntry中。可以看出,废弃项目的回收依赖于显示的触发,否则就要等待线程结束,进而回收相应的ThreadLocalMap!这就是很多OOM的来源,所以通常建议:

    1. 应用一定要自己负责remove
    2. 不要和线程池配合,因为worker线程往往是不会退出的

    下面举一个例子说明,ThreadLocal在线程池中使用的坑
    使用SpringBoot创建一个Web应用程序,使用ThreadLocal存放一个Integer的值,来暂且代表需要在线程中保存的用户ID,这个值初始时null,在业务逻辑中,会把外部传入的用户ID设置到ThreadLocal中,示例代码如下

    @RestController
    public class WrongDemoController {
        private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);
    
        @GetMapping("/wrong")
        public Map wrong(@RequestParam(value = "userId") Integer userId) {
            String before = Thread.currentThread().getName() + ":" + currentUser.get();
            currentUser.set(userId);
            String after = Thread.currentThread().getName() + ":" + currentUser.get();
            Map result = new HashMap();
            result.put("before", before);
            result.put("after", after);
            return result;
        }
    }
    

    线程池会重用固定的几个线程,为了更快地重现问题,在配置文件中设置一下tomcat的参数,把工作线程池最大线程数设置为1,这样始终是同一个线程在处理请求:

    server.tomcat.max-threads=1
    

    在浏览器中依次输入userId=1和userId=2,可以看出:

    • 当userId=1时,设置ThreadLocal之前和之后,从ThreadLocal中拿到的值分别为null和1
    • 当userId=2时,设置ThreadLocal之前和之后,从ThreadLocal中拿到的值分别为1和2


      image.png
      image.png

    问题出现了,为什么当userId=2时,从ThreadLocal拿到的初始值是1呢?原因是tomact的工作线程被重用了(在我们的例子中只有一个工作线程),那么很可能从ThreadLocal中拿到的值是别的用户的请求遗留的值(真实生产环境可能会导致用户信息错乱)。
    解决方案:
    ThreadLocal工具用来存放一些数据时,需要特别注意在代码运行完后,显示地去清空设置的数据。比如在上面的案例中,可以再finally代码块中显示清除ThreadLocal中的数据。

    @RestController
    public class WrongDemoController {
        private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);
    
        @GetMapping("/wrong")
        public Map wrong(@RequestParam(value = "userId") Integer userId) {
            try{
                String before = Thread.currentThread().getName() + ":" + currentUser.get();
                currentUser.set(userId);
                String after = Thread.currentThread().getName() + ":" + currentUser.get();
                Map result = new HashMap();
                result.put("before", before);
                result.put("after", after);
                return result;
            }finally {
                // 显示清除ThreadLocal中的数据
                currentUser.remove();
            }
    
        }
    }
    

    相关文章

      网友评论

          本文标题:ThreadLocal源码分析及避坑实践

          本文链接:https://www.haomeiwen.com/subject/imqtnftx.html