美文网首页java源码学习
ThreadLocal、ThreadLocalMap源码分析

ThreadLocal、ThreadLocalMap源码分析

作者: 慕北人 | 来源:发表于2020-03-30 21:41 被阅读0次

    ThreadLocal源码学习

    ThreadLocal的工作过程更像是一个工具人,其核心代码set、get等都是通过ThreadLocalMap实现的,ThreadLocal只是作为这个Map中的key。所以我们看的顺序先从ThreadLocalMap看起。

    一、ThreadLocalMap

    1. ThreadLocalMap.Entry

    该类代表ThreadLocalMap中数据保存的形式:

    static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
    
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }  
    

    这里有一点要注意,在Entry中,ThreadLocal类型的key被封装成了一个虚引用类型WeakReference,这里的原因我们最后会总结。

    2. 重要属性、构造器和几个简单方法

    • static final int INITIAL_CAPACITY:初始的容量,值为16;
    • Entry[] table:真正存放键值对的数组
    • int size:table中键值对的数量
    • int threshold:需要resize时的下一个size

    TreadLocalMap中有两个构造器,但是在我们使用ThreadLocal的过程中,被调用的只有下面这个:

    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }      
    

    构造器中无非就是对一些成员进行初始化,没有什么特殊之处。

    在上面的构造器中调用了setThreshold方法

    private void setThreshold(int len) {
            threshold = len * 2 / 3;
    }  
    

    可见,threshold的值被设为了参数长度的三分之二倍。

    private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }  
    
    private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }  
    

    nextIndex方法的效果相当于从初始的i每次都往len位置前进一步;prevIndex方法的效果相当于从初始的i每次都往0靠近一步。

    3. 核心方法

    getEntry
    private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }  
    

    该方法计算table数组下标的方式和HashMap中是一样的,也是进行位运算,那么这里table的length也必定是2的幂,值得缀的是这里判断key是否是同一个key的时候使用的是“连等”运算符。当该位置不是要找的元素时,会调用getEntryAfterMiss方法

    getEntryAfterMiss方法
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;
    
            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }    
    

    首先,这个while循环中,很好理解如果k就是key的话直接返回这个Entry。对于k为null的情况,需要注意这里有一点需要关注,在while循环体的大前提是该位置的节点Entry不是null,但是会有Entry不为null而k为null的情况(这就是将key包裹成一个WeakRefrence的结果;在后面我们把这种Entry节点成为陈旧Entry),这里会针对这种情况调用expungeStaleEntry方法进行处理(下文会分析)。还有一点需要理解的是,这个while循环,如果一直正常下去的话,那么i的变化是从初始位置一直到table的末尾后再从0开始,所以这个i导致这个while的循环有一种语义:直到下一个Entry为null的位置

    set方法
    private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
    
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
    
                if (k == key) {
                    e.value = value;
                    return;
                }
    
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
    
            ------------- 重点 1 ---------------
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }    
    

    我们来分析一下这个方法:

    1. 这个for循环的意思和之前的一样,找到下一个Entry不为null的位置
      1. 如果找到了key对应的点,那么直接替换value
      2. 如果找到了一个k为null的点,说明该位置可以被重用,那么调用replace替换掉(该方法下文会分析)
    2. 代码能够执行到重点1处的情况代表:table知道下一个为null的节点之间没有位置可以供新的值插入,那么这时候我们把新的节点插入到这个为null的位置
    3. 最后,产生了2中的情况说明当前的hash冲突有点严重了,所以通过cleanSomeSlots来清理陈旧的节点以及rehash重新为table中的元素排序
    remove方法
    private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }  
    

    该方法的意思还挺明显的,要注意的是他在调用e.clear()方法之后又调用了expungeStaleEntry()方法用来剔除陈旧的节点。

    4. 重中之重

    1. expungeStaleEntry方法

    在上面的getEntryAfterMiss方法和remove方法中都有该方法的回调,而且回调的位置,也就是传递的参数对应的意义是,该位置的Entry为陈旧Entry(即Entry不为null但是key为null的节点)

    private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
    
            -------- 重点 1 -----------
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;
    
            // Rehash until we encounter null
            Entry e;
            int i;
    
            ------------ 重点 2 -------------
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
    
                    --------- 重点 3 -----------
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
    
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }
    
    • 重点1:这里的代码直接将该节点以及value都设置为了null
    • 重点2:这个for循环我们太熟悉了,意思就是直到下一个Entry为null的位置
      • 首先if语句的作用把遇到的所有的陈旧节点都置为null,即都给清理掉
      • 重点3:这里的else语句中首先获取了一个普通的、正常的节点,然后判断其期望的hash下标与其真实的hash下标是否符合
        • 如果不符合则说明该节点是之前遇到了hash冲突的情况,那么在我们前面清理了很多陈旧节点的情况下,该节点期望的hash下标可能已经空了出来,如此下面的一个while循环就为其寻找一个合适的位置

    整个下来,该方法的作用就是,清除一部分陈旧的节点,rehash一部分普通的节点

    2. cleanSomeSlots

    在set方法中我们调用了这个方法,传递的参数的意义是:i代表之前这个位置为null的一个位置,只不过在传过来时已经被设置为了新加入的节点;n代表的是新的size

    private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }  
    

    注意while循环的条件,这里循环会遍历log(n)个节点,并不会遍历所有的节点。而expungeStaleEntry方法我们上面分析过,所以cleanSomeSlots方法的作用是遍历log(n)个节点,将其中的陈旧节点都清除。

    3. rehash()

    private void rehash() {
            expungeStaleEntries();
    
            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }    
    
    private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
    

    之前分析expungeStaleEntry方法时我还有点纳闷,怎么在expungeStaleEntry里面就涉及到rehash的操作了,原来rehash方法回调的expungeStaleEntries方法中就是在每一个陈旧节点调用expungeStaleEntry方法

    注意到一点,当rehash方法执行完expungeStaleEntries方法之后,此时table数组中的元素都是非陈旧节点和null节点,那么如果此时数据的数量size满足size大于等于threshold的四分之三的haul,为了避免hash冲突,就需要进行扩容操作。

    4. resize方法

    private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
    
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
    
            setThreshold(newLen);
            size = count;
            table = newTab;
        }  
    

    可见,就是将每一个节点重新计算其hash值之后放入新的table数组中;而且每次扩容都是扩容成原来的二倍。

    值得注意的是,setThreshold方法只在这里和构造器中调用过;而setThreshold方法的实现为:

    private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }  
    

    也就是说,在ThreaLocalMap中,threshold的值为table长度的三分之二,而扩容触发的条件是数据的个数达到threshold的四分之三,也就是数据的个数达到了table容量的一半就需要扩容

    小结

    1. 为何要把key设置为虚引用类型?

    到这里我们再来说一说为何要讲key设置为WeakRefrence,不同于HashMap,ThreadLocalMap的key为ThreadLocal对象,也就是说在我们使用的过程中,这个ThreadLocal对象有可能为null,从而被GC回收掉,这就表明我们之后再也没法访问该ThreadLocal为key的value数据了,那么为了节省空间,该位置理应让出来共他人使用,但是如果我们的key不是虚引用类型的话,那么永远在ThreadLocalMap.Entry中会有一个key持有ThreadLocal的强引用,导致该ThreadLocal的内存无法释放从而造成内存泄漏。

    但是如果使用虚引用类型的话,当ThreadLocal原来强引用类型的变量被赋值为null,等到GC到来的时候,那么该ThreadLocalMap.Entry中key对这个ThreadLocal的引用也会被回收,这就产生了一个陈旧节点,而key为null这一标志也可以作为我们判断该节点是否需要回收让给其他ThreadLocalMap.Entry的判断依据。

    2. 既然这种数据结构可以在Key为null的时候回收从而节省空间,那么为啥不推广到HashMap中?

    产生这种优势的前提是因为ThreadLocalMap在处理hash冲突的时候采用的是开放地址方法,即如果当前hash位置被其他的节点占用了,那么会从数组往后找空闲的位置;而HashMap中解决hash冲突采用的方案为使用链表或者红黑树,这样一来,获取每一个节点就显得有些耗时,从而导致ThreadLocalMap中的rehash方法在HashMap中实现效果不会很好。

    二、ThreadLocal

    我们分析了ThreadLocalMap的实现,那么ThreadLocal就非常简单了。

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }  
    
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }  
    
    public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }  
    

    可见,它的方法都是调用的ThreadLocalMap的方法实现的,但是ThreadLocalMap的方法是getMap方法得到的:

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }  
    

    我们发现原来线程自己有一个ThreadLocalMap成员,经过查找,发现Thread自己没有一个地方去为该成员进行初始化操作,而真正的初始化操作就发生在ThreadLocal第一次set的时候调用createMap方法:

     void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }  
    

    至此,ThreadLocal源码的学习就到此结束了~

    相关文章

      网友评论

        本文标题:ThreadLocal、ThreadLocalMap源码分析

        本文链接:https://www.haomeiwen.com/subject/ipmbuhtx.html