概述
我们Threadlocal类的作用是提供一个线程间隔离,线程内部共享的数据。今天我们一起看看TreadLocal是怎么做到线程隔离的。
例子
例子同样可以在github中找到
public static void testThreadLocal() {
ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
System.out.println(Thread.currentThread().getName() + ".set: " + -1);
threadLocal.set(-1);
ExecutorService executorService = Executors.newCachedThreadPool();
for (int i=1; i< 5; i++) {
final Integer setValue = i;
executorService.submit(() -> {
System.out.println(Thread.currentThread().getName() + ".set: " + setValue);
threadLocal.set(setValue);
System.out.println(Thread.currentThread().getName() + ".get: " + threadLocal.get());
threadLocal.remove();
});
}
System.out.println(Thread.currentThread().getName() + ".get: " + threadLocal.get());
threadLocal.remove();
}
运行结果:
main.set: -1
pool-1-thread-1.set: 1
pool-1-thread-2.set: 2
pool-1-thread-2.get: 2
pool-1-thread-3.set: 3
pool-1-thread-1.get: 1
pool-1-thread-3.get: 3
pool-1-thread-4.set: 4
main.get: -1
pool-1-thread-4.get: 4
代码中threadLocal对象看着也是被多线程竞争写入的,多个线程同时对他进行写入,但每个线程get到的都是正确的结果,为什么可以做到线程隔离呢?
源码
我们先大致看看set方法
public void set(T value) {
//得到当前线程
Thread t = Thread.currentThread();
//获取线程的ThreadLocalMap属性
ThreadLocalMap map = getMap(t);
//map不为空时,set threadlocal 和value
if (map != null)
map.set(this, value);
else
createMap(t, value); //为空时创建一个map并将threadlocal 和value放入
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
原来threadLocal set value的时候,首先获得当前的线程对象,然后得到线程对象的ThreadLocalMap属性,然后将threadlocal自身作为key, set到map中。图解一下Thread类和ThreadLocal类的关系。
原来Thread对象中有个ThreadLocalMap属性,ThreadLocalMap顾名思义就是存放ThreadLocal的map。所以虽然例子中看着threadLocal是竞争的写入,其实不是,都是在自己的线程对象中维护了一个threadLocal。
get方法也清晰了,就是从Thread对象里拿key为这个threadLocal对象的 value值呗!
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
//map中当前threadlocal作为key,拿到value的值,并返回
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
看到这ThreadLocal类的原理就说完了。
但Doug Lea 哪是一般人,源码中围绕着减少内存泄漏做的很多努力。下面我们就看看为什么会发生内存泄漏,以及怎么防止内存泄漏。
名词解释:
什么是内存泄漏?
- 无用对象(不再使用的对象)持续占有内存或无用对象的内存得不到及时释放,从而造成内存空间的浪费称就叫做内存泄漏。
为什么会发生内存泄漏?
- new出来的ThreadLocal对象有两个地方引用,threadLocal变量和线程属性中的threadLocalMap中的key,如果将threadLocal变量赋值为空后,因为线程的成员变量和线程生命周期相同,垃圾回收器仍然不能回收,造成了内存泄漏。
- 同上即使key指向的threadlocal对象被垃圾回收了,value指向的对象仍然存活着,还是有内存泄漏。
怎么解决?
- 防止threadLocal对象的内存泄漏,使用弱引用。
- 防止value对象的内存泄漏,使用过期检查和清理,以及提供remove方法(后面会详细介绍)。
强软弱虚四种引用的定义及使用场景
- 强引用: 普通的引用,引用存在时,垃圾回收器不能回收。我们new出来的对象都是强引用的。
- 软引用: 垃圾回收后,内存还是不够,进行回收,内存够用是不会回收的(不干掉你我JVM就内存溢出了)。 软引用适合做缓存
- 弱引用: 只有弱引用指向对象时,垃圾回收时就会回收。threadLocal中用于防止内存泄漏。
- 虚引用: 有没有垃圾回收都get不到引用,用于管理直接内存,对象回收时,放入指定队列中,垃圾回收器额外处理指向的直接内存。
- 例子可以在github中查看
上面是方法运行时,栈中内存和堆中内存的示例图,方便我们理解。
为什么弱引用可以帮我们解决key上的内存泄漏呢?
- 根据弱引用的定义,上图中当threadlocal变量指向threadLocal对象的强引用被干掉时(即threadlocal=null),只有map中的key弱弱的指向它,垃圾回收器看它没用了,立马回收掉。这就解决了threadlocal对象的内存泄漏。下面源码看看实现吧
static class ThreadLocalMap {
//这里的源码可以看到map中Entry类继承了WeakReference类,key弱弱的引用ThreadLocal对象
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
。。。
}
上面说了防止value对象的内存泄漏,使用过期检查和清理,以及提供remove方法,这里是ThreadLocal最复杂的一部分,我们详细看看吧。再看set方法。
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
调用map的set方法,并不是常用的put方法,看来有不是简单的存值啊
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
//根据hashcode和Entry数组的长度,计算下标值
int i = key.threadLocalHashCode & (len-1);
//根据得到的下标值找,遇到hash冲突就向后移动一个,直到找到entry是空的节点
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
//遇到key相等,说明key之前存过,替换value值就行了
if (k == key) {
e.value = value;
return;
}
//如果k是空,说明这里存的是一个过期数据,进行替换
//这里会进行过期数据的清理
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
//前面的位置都被占用着,新建一个Entry放在i上
tab[i] = new Entry(key, value);
//将map中的size加1
int sz = ++size;
//扫描清理一次过期数据,如果还是达到扩容的阈值了,进行扩容
//这里也会进行过期数据的清理
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();//先进行一次全量的扫描清理过期数据,还是快接近阈值就扩容
}
replaceStaleEntry方法
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// 向前扫描第一个过期的节点
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i; //标识第一个需要清除的位置
// 向后遍历
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 向后找到了key,把value进行替换
if (k == key) {
e.value = value;
//i节点设为过期数据
tab[i] = tab[staleSlot];
//之前的过期节点赋值为key的Entry数据
tab[staleSlot] = e;
// 如果staleSlot就是第一个过期数据(上面的for进行了一次向前扫描),把过期下标设为i
if (slotToExpunge == staleSlot)
slotToExpunge = i;
//expungeStaleEntry方法清理过期节点,并进行整理(因为存在hash冲突后移,可能某些节点的hash位置空出来了,放入对应的自己的位置,后面会有图解说明)
//cleanSomeSlots会清理Log n次,为了效率不能每次都全量扫描
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// staleSlot是第一个过期数据,把slotToExpunge标记为i 说明有其他过期节点
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 过期位置赋值为用key value构建的新Entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 如果slotToExpunge != staleSlot说明有其他节点也过期了,继续清理一些其他过期节点
//和for循环中slotToExpunge = i 呼应
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
replaceStaleEntry方法顾名思义用当前的key value构造一个entry替换这个过期的Entry节点。但因为存在hash冲突后移,并不能单纯的直接替换,所以做了上面的这么多事情
//清理下标为staleSlot的过期节点
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 过期节点设为空
tab[staleSlot].value = null; //help gc
tab[staleSlot] = null;
size--;
// 清理的过程中可能之前因为存在hash冲突后移的节点,位置恰好是staleSlot,staleSlot空出来了,节点应该放在正确的位置。
Entry e;
int i;
//向后扫描
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//节点key为空,说明已过期,直接干掉
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//计算节点的hash值,确定在数组中的位置
int h = k.threadLocalHashCode & (len - 1);
//如果节点不应该放在i位置上,则可能放在h到i中间的位置上
if (h != i) {
tab[i] = null;
// 从h位置一直后移,找到第一个为空的位置,放在正确的位置上(hash冲突后移的逻辑)
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
cleanSomeSlots 清理部分过期Entry
//进行log n次扫描
//如果没有发现过期节点返回false(没有节点移动)
//如果发现了过期节点,清理过期节点,n重置为table数组的length,再次扫描log n次
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
上面是set方法中对防止内存泄漏的一些努力,每次set都会对一些过期节点进行清除整理,这一部分也是较难理解的。我们放一张图,方便大家理解。
我们看看get方法,会发现也对防止内存泄漏做了一些努力
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
//当map为空时,创建一个map并存入key:this value:null 返回null
return setInitialValue();
}
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
//这里是重点,当hash的位置被其他节点占用了,可能是冲突后移了,可能就是没有
return getEntryAfterMiss(key, i, e);
}
getEntryAfterMiss 方法
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
//向后找,直到找到entry节点是空时,返回null
while (e != null) {
ThreadLocal<?> k = e.get();
//k正好是我们要找的数据,返回节点entry
if (k == key)
return e;
//如果k是空,说明是过期节点,清除该过期节点
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
remove方法
//手动清理threadLocal
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
//得到key对应的entry
if (e.get() == key) {
e.clear(); //将referent赋值为null
expungeStaleEntry(i); //清理该节点
return;
}
}
}
你可能有疑问,既然set和get方法都会移除过期节点,还要我们remove吗?
强烈建议大家使用完threadlocal后一定要调用remove方法。
填坑记
我们曾经一个项目中使用了threadlocal,业务上是这样的
- 根据参数得到一个商家类型,如果是类型A就把A放入threadlocal中,如果是类型B就不放。
- threadlocal是线程间隔离,线程中共享的嘛,后面的代码就可以根据threadlocal.get判断商家类型了。
- 快上线做测试回归的时候,发现总是有概率商家类型会判断错误,一群人加班,后来发现是threadlocal用完后,没有调用remove方法。你知道为什么会这样吗?
- 因为tomcat线程池,线程是重用的,如果线程t1上次使用是被放了A进去,因为t1没有销毁,下次访问A还在里面,即使这次商家类型是B,但B没有重写进去,调用thread.get 得到的仍然是A。所以再次建议大家使用完threadlocal后,一定要进行remove