美文网首页
ThreadLocal 源码分析

ThreadLocal 源码分析

作者: 枫叶栈 | 来源:发表于2017-12-04 19:43 被阅读0次

    一、概述

    是不是觉得它是一个线程?不要被名字迷惑,它并不是一个线程。

    《从源码理解Android Handler消息机制》一文中,我们提到ThreadLocal,当时我们这么解释:ThreadLocal 你可以理解为保存一个在线程范围内可见的变量。那么ThreadLocal是如何做到的呢?Follow Me ,看看源码如何实现的。

    二、源码分析

    平常我们使用ThreadLocal都是调用其set()和get()方法,基于这两个方法为切入点我们来分析下它的实现原理。

    老规矩,源码是最好的解释,直接上源码:

    代码 1.1
        public void set(T value) {
            Thread t = Thread.currentThread();//获取当前调用的线程
            ThreadLocalMap map = getMap(t);//往下面看
            if (map != null)
                map.set(this, value);//直接往map添加数据 查看代码1.3
            else
                createMap(t, value);//查看代码1.2
        }
    
        ThreadLocalMap getMap(Thread t) {
             //直接返回线程的一个变量 我们发现是 ThreadLocal.ThreadLocalMap threadLocals = null;
            return t.threadLocals;
        }
    
    static class ThreadLocalMap {}//名字叫Map 并没有实现Map接口
    

    上边的set()方法里主要内容:

    1. 获取线程的ThreadLocalMap threadLocals 对象;
    2. 根据threadLocals 是否为空来决定是创建ThreadLocalMap 还是往ThreadLocalMap 添加对象;

    下边看下createMap方法:

    代码 1.2
     void createMap(Thread t, T firstValue) {
            t.threadLocals = new ThreadLocalMap(this, firstValue);//生成ThreadLocalMap
        }
    
      ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {//ThreadLocal为key
             table = new Entry[INITIAL_CAPACITY];
            //注意ThreadLocal的threadLocalHashCode 
             int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
             table[i] = new Entry(firstKey, firstValue);
             size = 1;
             setThreshold(INITIAL_CAPACITY);
        }
    
            /**
             * The initial capacity -- MUST be a power of two.
             */
            private static final int INITIAL_CAPACITY = 16;//必须是二的幂
    
    

    createMap()方法里主要做了:

    1. 生成ThreadLocalMap实例;
    2. 用ThreadLocal作为key,然后生成一个节点放入数组,至于数组位置,则由ThreadLocal的threadLocalHashCode&(INITIAL_CAPACITY -1)决定;
    3. INITIAL_CAPACITY 这个值必须是2的幂,初始为16;
    4. 神奇的 0x61c88647 ,每当我们new一个ThreadLocal对象,新对象的threadLocalHashCode值等于在静态变量nextHashCode变量上加 0x61c88647,至于原因看下边的数据测试:
    public class ThreadLocal<T> {
     
        private final int threadLocalHashCode = nextHashCode();
    
        private static AtomicInteger nextHashCode = new AtomicInteger();//原子变量  通过CAS操作更新
    
        private static final int HASH_INCREMENT = 0x61c88647;
    
        private static int nextHashCode() {
            return nextHashCode.getAndAdd(HASH_INCREMENT);
        }
    
        //我们看下 0x61c88647如何神奇:
      public static void main(String[] args) {
           int hashCode = 0x61c88647;
            
           System.out.println("数组length 为 16 ");
           for(int i =0;i<16;i++){
               System.out.print((15&(i*hashCode))+"  ");
           }
           
            System.out.println("");
            System.out.println("数组length 为 32 ");
           
            for(int i =0;i<32;i++){
                System.out.print((31&(i*hashCode))+"  ");
            }
        }
    运行结果:
    数组length 为 16 
    0  7  14  5  12  3  10  1  8  15  6  13  4  11  2  9  
    数组length 为 32 
    0  7  14  21  28  3  10  17  24  31  6  13  20  27  2  9  16  23  30  5  12  19  26  1  8  15  22  29  4  11  18  25  
    
    结果很神奇,这个跟数学相关,我也不是很清楚为什么,总之运行结果是散列的分散在数组中。
    

    接下来我们看下ThreadLocalMap 的 set()方法:

    代码1.3
    
            private void set(ThreadLocal key, Object value) {
    
                // We don't use a fast path as with get() because it is at
                // least as common to use set() to create new entries as
                // it is to replace existing ones, in which case, a fast
                // path would fail more often than not.
    
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);//有没有很熟悉
    
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {//线性探测法
                    ThreadLocal k = e.get();
    
                    if (k == key) {//key相同 替换值
                        e.value = value;
                        return;
                    }
                     //Entry 集成自 WeakReference<ThreadLocal>  k很有可能为null
                    if (k == null) {
                        replaceStaleEntry(key, value, i);//查看代码1.4
                        return;
                    }
                }
    
                tab[i] = new Entry(key, value);//表里没数据  生成节点加入进去
                int sz = ++size;//更改当前size
                if (!cleanSomeSlots(i, sz) && sz >= threshold)//判断是否触发阈值 触发则扩容
                    rehash();//查看代码1.7
            }
    
    

    主要用线性探测法向数组中确定节点位置,与HashMap的链地址法实现方式不一样。

    代码1.4
       //replaceStaleEntry方法  
            private void replaceStaleEntry(ThreadLocal key, Object value,
                                           int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
                Entry e;
    
                int slotToExpunge = staleSlot;//记录删除节点位置起始位置 
                for (int i = prevIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = prevIndex(i, len))//从数组往前找 有节点但节点无key值则更新slotToExpunge ,否则停止查找
                    if (e.get() == null)
                        slotToExpunge = i;
    
                for (int i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {//线性探索查找key相同节点
                    ThreadLocal k = e.get();
    
                    if (k == key) {//如果 k == key 则更新value 讲该节点更新到 staleSlot位置上 
                        e.value = value;
    
                        tab[i] = tab[staleSlot];
                        tab[staleSlot] = e;
    
                        // Start expunge at preceding stale entry if it exists
                        if (slotToExpunge == staleSlot)
                            slotToExpunge = i;
                       //清除部分节点expungeStaleEntry()查看代码1.5 cleanSomeSlots()查看代码1.6
                        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                        return;
                    }
    
                    if (k == null && slotToExpunge == staleSlot)
                        slotToExpunge = i;
                }
    
                // If key not found, put new entry in stale slot
                tab[staleSlot].value = null;
                tab[staleSlot] = new Entry(key, value);//未找到 生成新节点放入
    
                // If there are any other stale entries in run, expunge them
                if (slotToExpunge != staleSlot)
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            }
    
    代码1.5 方法expungeStaleEntry()
           // 从删除节点到后边遍历 到第一个为 null节点之间的节点都经过检测 返回第一个null节点位置
            private int expungeStaleEntry(int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
    
                // expunge entry at staleSlot
                tab[staleSlot].value = null;
                tab[staleSlot] = null;//删除该位置节点
                size--;
    
                // Rehash until we encounter null
                Entry e;
                int i;
                for (i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {//从当前节点现象探索遍历
                    ThreadLocal k = e.get();
                    if (k == null) {//删除 key为null节点 
                        e.value = null;
                        tab[i] = null;
                        size--;
                    } else {
                        int h = k.threadLocalHashCode & (len - 1);//根据数组长度计算出新的位置 因为数组有可能扩容
                        if (h != i) {//不相等的情况需要改变该节点位置
                            tab[i] = null;
    
                            // Unlike Knuth 6.4 Algorithm R, we must scan until
                            // null because multiple entries could have been stale.
                            while (tab[h] != null)//讲h位置节点进行线性探测法确定位置
                                h = nextIndex(h, len);
                            tab[h] = e;//讲e节点更新到h位置
                        }
                    }
                }
                return i;
            }
    
    代码 1.6 方法cleanSomeSlots() 用于清除线性探索方向上的空节点
            private boolean cleanSomeSlots(int i, int n) {
                boolean removed = false;
                Entry[] tab = table;
                int len = tab.length;
                do {
                    i = nextIndex(i, len);
                    Entry e = tab[i];
                    if (e != null && e.get() == null) {
                        n = len;
                        removed = true;
                        i = expungeStaleEntry(i);// 见 2.2.2 
                    }
                } while ( (n >>>= 1) != 0);//n = n>>>1 无符号右移动并赋值 这边每次除以2有点不太理解 欢迎大家讨论
                return removed;
            }
    
    代码 1.7
            private void rehash() {
                expungeStaleEntries(); //见下边
    
                // Use lower threshold for doubling to avoid hysteresis
                if (size >= threshold - threshold / 4)
                    resize();//见1.8
            }
    
            /**
             * Expunge all stale entries in the table.
             */
            private void expungeStaleEntries() {
                Entry[] tab = table;
                int len = tab.length;
                for (int j = 0; j < len; j++) {
                    Entry e = tab[j];
                    if (e != null && e.get() == null)
                        expungeStaleEntry(j);//见1.5
                }
            }
        }
    
    代码1.8 resize()
            private void resize() {
                Entry[] oldTab = table;
                int oldLen = oldTab.length;
                int newLen = oldLen * 2;//扩容 容量依然是2的幂
                Entry[] newTab = new Entry[newLen];//生成新的数组
                int count = 0;
    
                for (int j = 0; j < oldLen; ++j) {//遍历旧的数组
                    Entry e = oldTab[j];
                    if (e != null) {
                        ThreadLocal k = e.get();
                        if (k == null) {
                            e.value = null; // Help the GC
                        } else {
                            int h = k.threadLocalHashCode & (newLen - 1);//计算在新数组中的位置
                            while (newTab[h] != null)// 冲突后 线性探测法
                                h = nextIndex(h, newLen);
                            newTab[h] = e;
                            count++;
                        }
                    }
                }
                //更新属性
                setThreshold(newLen);
                size = count;
                table = newTab;
            }
    

    上边就是ThreadLocal中set()方法的实现,主要: 向数组中插入节点,根据key (ThreadLocal)的threadLocalHashCode&(len-1)决定位置,然后根据线性探索法解决冲突问题,包括如果数组size超过阈值则扩容。

    下边分析下get()方法:

    代码2.1
        public T get() {
            Thread t = Thread.currentThread();
            ThreadLocalMap map = getMap(t);
            if (map != null) {
                ThreadLocalMap.Entry e = map.getEntry(this);//查看2.2
                if (e != null)
                    return (T)e.value;
            }
            return setInitialValue();//这是一个空方法,如果未命中则调用用该方法返回的默认value
        }
    
    代码2.2
            private Entry getEntry(ThreadLocal key) {
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                if (e != null && e.get() == key)
                    return e;//命中
                else
                    return getEntryAfterMiss(key, i, e);//未命中 见下方
            }
    
            private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
                Entry[] tab = table;
                int len = tab.length;
    
                while (e != null) {//线性探索查找
                    ThreadLocal k = e.get();
                    if (k == key)
                        return e;
                    if (k == null)
                        expungeStaleEntry(i);
                    else
                        i = nextIndex(i, len);
                    e = tab[i];
                }
                return null;//未找到返回null
            }
    

    通过get()我们可以看出:

    1. 根据key (ThreadLocal)的threadLocalHashCode&(len-1)位置的值是否命中,命中返回,没有命中则根据线性探索法查找节点;
    2. 第一步没找到则调用setInitialValue()方法返回值来充当返回值,该方法用户可以重写;

    下边看下remove()方法

    代码3.1
         public void remove() {
             ThreadLocalMap m = getMap(Thread.currentThread());
             if (m != null)
                 m.remove(this);// 见下方
         }
    
            private void remove(ThreadLocal key) {
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);//熟不熟悉 
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {//线性探索法查找
                    if (e.get() == key) {
                        e.clear();
                        expungeStaleEntry(i);//见1.5
                        return;
                    }
                }
            }
    

    三、总结

    上边就是为大家分析的ThreadLocal的实现,主要实现依靠:

    1. 每个线程保留一个ThreadLocalMap 变量;
    2. 当我们向ThreadLocal中放入值的时候,其实我们是将值放入到了Thread的threadLocals中;
    3. 没当我们实例一个ThreadLocal的时候,该实例的threadLocalHashCode值会改变,ThreadLocalMap中的table数组长度记为len,则不同实例的threadLocalHashCode&(len-1)会散列在table数组的不同位置;
    4. ThreadLocalMap中table属性中的Entry继承自WeakReference<ThreadLocal>,所以key很容易被回收;
    5. 当出现hash冲突时,是使用线性探索法查找,不同于HashMap的查找原理;

    以上就是为大家分享的ThreadLocal源码分析。感谢你的耐心阅读,如有错误,欢迎指正。如果本文对你有帮助,记得点赞。欢迎关注我的微信公众号:


    qrcode_for_gh_84a02a29fedd_430.jpg

    相关文章

      网友评论

          本文标题:ThreadLocal 源码分析

          本文链接:https://www.haomeiwen.com/subject/mnztixtx.html