美文网首页
ThreadLocal分析

ThreadLocal分析

作者: 雨之都 | 来源:发表于2021-07-09 15:49 被阅读0次

    上一次看ThreadLocal的源代码已经是很久之前的事情了,今早突然想起发现自己连ThreadLocal的原理一点也想不起了,因此重新再读一次源码,分析一下ThreadLocal的原理

    ThreadLocal正如其名(线程本地)这是指对象设置或者获取的值都是当前线程访问的,其他线程设置和访问的不是同一个对象(当前前提是initialValue和setValue使用姿势正确).诸如数据库的连接对象就可以使用ThreadLocal来保存.下面就可以展开分析了

    对于ThreadLocal来说,公开的函数就是

    • set(T) void
    • get(): T
    • remove(): void

    通常ThreadLocal的使用姿势有,直接构造一个ThreadLocal对象,然后调用set 设置值, 调用get获取值,这种情况是对于ThreadLocal没有初值的情况,因此如果我们在调用get之前没有调用set.那么第一次获取的值就是空的,对于这种情况,ThreadLocal提供了一个保护方法

    • initialValue(): T

    子类通过覆写这个方法,使得ThreadLocal在get的时候第一次能够获取到初始值,ThreadLocal的一个静态函数

    • withInitial(Supplier<? extends S>): ThreadLocal<S> 就是这个原理

    构造函数

    public ThreadLocal() {
        }
    

    可以看到ThreadLocal的构造函数是空的,平平无奇,ThreadLocal的魔法应该是在set和get的时候会发生的,构造函数应该也不需要做什么特别的工作

    get函数

    public T get() {
            Thread t = Thread.currentThread();
            ThreadLocalMap map = getMap(t);
            if (map != null) {
                ThreadLocalMap.Entry e = map.getEntry(this);
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    T result = (T)e.value;
                    return result;
                }
            }
            return setInitialValue();
        }
    

    根据源代码,查找主要分为如下几个步骤

    1. 获取当前Thread
    2. 根据当前Thread关联的ThreadLocalMap
    3. 如果不为空,调用其查找函数
    4. 如果为空的化,那么调用setInitialValue设置初值并且返回,setInitialValue和set函数的流程基本查不到,所以这里不赘述,后文分析完set函数之后,基本也就明白了它的功能了
    5. 因此我们对不为空的查找函数,深入去了解一下
    private Entry getEntry(ThreadLocal<?> key) {
                            // 通过当前ThreadLocal的hashKey获取目标位置
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                            // 如果目标位置的元素不为空且key相同,那么就查找完毕,返回该Entry
                if (e != null && e.get() == key)
                    return e;
                else
                    return getEntryAfterMiss(key, i, e);
            }
    

    可以看到当第一次在目标位置没有找到的时候,会调用getEntryAfterMiss函数,我们看一下该函数的实现;可以看到就是往后线性遍历,一直到Entry为空,未找到则直接返回为空,我们注意到中间有一步是当获取的key为空的情况(当然可能! 因为key是ThreadLocal通过弱引用的方式保存的,如果ThreadLocal被销毁了,那么key就是为空了)

    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                Entry[] tab = table;
                int len = tab.length;
    
                while (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == key)
                        return e;
                    if (k == null)
                        expungeStaleEntry(i);
                    else
                        i = nextIndex(i, len);
                    e = tab[i];
                }
                return null;
            }
    

    垃圾回收可能发生在任何时间,所以当key无效的时候,我们应该做清理工作,我个人理解的清理工作是遍历从i到后面所有的已经过期了的,将这些移除,并且对于之后的元素rehash.重新插入队列,那么我们看一下代码观察看是否做了这样的事情

    private int expungeStaleEntry(int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
    
                // expunge entry at staleSlot
                tab[staleSlot].value = null;
                tab[staleSlot] = null;
                size--;
    
                // Rehash until we encounter null
                Entry e;
                int i;
                for (i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null;
                        tab[i] = null;
                        size--;
                    } else {
                                            // 不为空的元素就重新插入
                        int h = k.threadLocalHashCode & (len - 1);
                        if (h != i) {
                            tab[i] = null;
    
                            // Unlike Knuth 6.4 Algorithm R, we must scan until
                            // null because multiple entries could have been stale.
                            while (tab[h] != null)
                                h = nextIndex(h, len);
                            tab[h] = e;
                        }
                    }
                }
                            // 下一个可以插入的为null的slot
                return i;
            }
    

    看代码确实就是做了这样的事情,所以插入实际上就是查找+线性遍历

    set函数

    public void set(T value) {
            // 获取当前的线程
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    

    看到第二步调用了getMap,并且把当前线程传入了,那么这里做了什么?展开看看

    ThreadLocalMap getMap(Thread t) {
            return t.threadLocals;
        }
    

    发现只是返回了线程的这个成员,那如果我们根本就没有设置过值的化,那么这个值理所当然是空的,因此会走到下面的createMap函数,继续往下跟

    void createMap(Thread t, T firstValue) {
            t.threadLocals = new ThreadLocalMap(this, firstValue);
        }
    

    可以看到直接创建了ThreadLocalMap对象,并且把this和firstValue当成构造函数的参数传入了

    那在分析流程继续之前,有必要看一下ThreadLocalMap的源码,看ThreadLocalMap的注释说,ThreadLocalMap是一个hashmap,是专门为了维护线程本地数据而造出的一个数据结构,因此它没有暴露出任何方法,但为了让Thread能够访问,所以ThreadLocalMap本身是包访问权限的

    之前研究HashMap的时候发现,HashMap无非是Entry的数组+链表,那ThreadLocalMap肯定也不例外,看一下它的Entry长什么样子

    static class Entry extends WeakReference<ThreadLocal<?>> {
                /** The value associated with this ThreadLocal. */
                Object value;
    
                Entry(ThreadLocal<?> k, Object v) {
                    super(k);
                    value = v;
                }
            }
    

    将ThreadLocal作为k,并且key传给了父类的构造函数,且因为父类是WeakReference.所以Entry的key是弱引用的,接着来看一下ThreadLocalMap的构造函数

    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
                table = new Entry[INITIAL_CAPACITY];
                int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
                table[i] = new Entry(firstKey, firstValue);
                size = 1;
                setThreshold(INITIAL_CAPACITY);
            }
    
    
    1. 根据INITIAL_CAPACITY构造了一个Entry的数组table

    2. 第二行代码计算了将要插入表中的位置

      • 里面访问了threadLocalHashCode的属性,看一下threadLocalHashCode长什么样子
      private final int threadLocalHashCode = nextHashCode();
      
      • 那nextHashCode做了什么呢,继续往下翻
          private static final int HASH_INCREMENT = 0x61c88647;
          /**
           * Returns the next hash code.
           */
          private static int nextHashCode() {
              return nextHashCode.getAndAdd(HASH_INCREMENT);
          }
      

      nextHashCode是一个原子类型的数据,每次调用这个方法都加上了一个HASH_INCREMENT,这个数字的具体原理没有深入研究,谷歌了一下发现通过这样的方式能够减少碰撞,暂且不表,

      • 通过第二步获取的nextHashCode和位的大小减1 进行位于,找到了元素被防止的防止,构造一个Entry.将其放入数组

      最后一步调用了了setThreshold方法设置了一下阈值,和hashmap的阈值是等同的

      private void setThreshold(int len) {
                  threshold = len * 2 / 3;
              }
      

      可以看到阈值是长度的2/3。

      如果我们创建了第二个ThreadLocal.同样调用设值。

      假设这个时候相应的ThreadLocalMap已经创建好了,那么就会走到If中的map.set(this, value)中去,看一下set方法长什么样子

      private void set(ThreadLocal<?> key, Object value) {
      
                  // We don't use a fast path as with get() because it is at
                  // least as common to use set() to create new entries as
                  // it is to replace existing ones, in which case, a fast
                  // path would fail more often than not.
      
                  Entry[] tab = table;
                  int len = tab.length;
                              // 通过threadLocalHashCode获取要插的下一个点,每一个ThreadLocal对象的
                              // threadLocalHashCode都不一致
                  int i = key.threadLocalHashCode & (len-1);
                              // 线性探测法去避免冲突
                  for (Entry e = tab[i];
                       e != null;
                       e = tab[i = nextIndex(i, len)]) {
                                      // 获取当前entry对应的key
                      ThreadLocal<?> k = e.get();
                                      // 如果相等就直接替换
                      if (k == key) {
                          e.value = value;
                          return;
                      }
                                      // 如果为空;说明当前的ThreadLocal对象被回收了;那么执行替换
                      if (k == null) {
                          replaceStaleEntry(key, value, i);
                          return;
                      }
                  }
      
                  tab[i] = new Entry(key, value);
                  int sz = ++size;
                              // 如果cleanSomwSlots没有清理移除元素,并且下面已经超过threshold了;那么需要执行rehash
                  if (!cleanSomeSlots(i, sz) && sz >= threshold)
                      rehash();
              }
      
                      private static int nextIndex(int i, int len) {
                  return ((i + 1 < len) ? i + 1 : 0);
              }
      
                      private static int prevIndex(int i, int len) {
                  return ((i - 1 >= 0) ? i - 1 : len - 1);
              }
      
                      
      private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                             int staleSlot) {
                  Entry[] tab = table;
                  int len = tab.length;
                  Entry e;
      
                  // Back up to check for prior stale entry in current run.
                  // We clean out whole runs at a time to avoid continual
                  // incremental rehashing due to garbage collector freeing
                  // up refs in bunches (i.e., whenever the collector runs).
                  int slotToExpunge = staleSlot;
                  for (int i = prevIndex(staleSlot, len);
                       (e = tab[i]) != null;
                       i = prevIndex(i, len))
                      if (e.get() == null)
                          slotToExpunge = i;
      
                  // Find either the key or trailing null slot of run, whichever
                  // occurs first
                  for (int i = nextIndex(staleSlot, len);
                       (e = tab[i]) != null;
                       i = nextIndex(i, len)) {
                      ThreadLocal<?> k = e.get();
      
                      // If we find key, then we need to swap it
                      // with the stale entry to maintain hash table order.
                      // The newly stale slot, or any other stale slot
                      // encountered above it, can then be sent to expungeStaleEntry
                      // to remove or rehash all of the other entries in run.
                      if (k == key) {
                          e.value = value;
                                              // 把当前的数值和过期的slot交换;这里必须要交换;否则就破坏了插入的原则;可能会导致之后查找失败
                          tab[i] = tab[staleSlot];
                          tab[staleSlot] = e;
                                              // 这个时候i的slot就是过期的
                          // Start expunge at preceding stale entry if it exists
                          if (slotToExpunge == staleSlot)
                              slotToExpunge = i;
                                              // slotToExpunge是要擦除的起点
                          cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                          return;
                      }
      
                      // If we didn't find stale entry on backward scan, the
                      // first stale entry seen while scanning for key is the
                      // first still present in the run.
                      if (k == null && slotToExpunge == staleSlot)
                          slotToExpunge = i;
                  }
      
                            // key并不在map里面;staleSlot是可以插入的,直接插入
                  tab[staleSlot].value = null;
                  tab[staleSlot] = new Entry(key, value);
                              
                  // slotToExpunge是前项已经过期了的,做一些清理工作
                  if (slotToExpunge != staleSlot)
                      cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
              }
      
      // 启发式地搜索,最多可能搜索o(n),通常情况应该是o(lgn)
      // 如果找到了一个过期的,就把这个过期的元素重新清理,并且把没有过期的重新hash重新插入
      // 关于expungeStaleEntry上文已经分析过了详细的流程
      private boolean cleanSomeSlots(int i, int n) {
                  boolean removed = false;
                  Entry[] tab = table;
                  int len = tab.length;
                  do {
                      i = nextIndex(i, len);
                      Entry e = tab[i];
                      if (e != null && e.get() == null) {
                          n = len;
                          removed = true;
                          i = expungeStaleEntry(i);
                      }
                  } while ( (n >>>= 1) != 0);
                  return removed;
              }
      

      在分析完成ThreadLocal后之后,我提出了我自己的几个 Q && A

      1. 为什么ThreadLocal是线程安全的

      因为ThreadLocal操作的是当前线程的一个threadLocals变量,不同线程操作的是不同的变量,同一时间,一个线程只可能有一个代码序列访问threadLocals.因此ThreadLocal是线程安全的

      1. 在一个线程里面创建无数个ThreadLocal,有没有可能有两个ThreadLocal的key完全一致?

      有可能,因为ThreadLocal的key的hashcode就是从0一直叠加魔法数字,所以创建大量的ThreadLocal可能导致两个key完全一致,但这个场景在实际中实际上不可能,我相信正常的开发同学也不会new异常数量的ThreadLocal的

      1. ThreadLocal的大致原理?

      实际上ThreadLocal就是散列+开放地址法(解决冲突),之所以看ThreadLocal的代码感觉有点复杂,是因为ThreadLocal还处理了每次插入的时候以及获取的时候去删除已经过期了的元素,所以这也是我们将ThreadLocal的key封装弱引用的原因

    相关文章

      网友评论

          本文标题:ThreadLocal分析

          本文链接:https://www.haomeiwen.com/subject/mopfpltx.html