解决方式(两种):
- 方式1:线程池任务装饰器 —— 使用 Spring 的 ThreadPoolTaskExecutor,并通过 setTaskDecorator 设置 TaskDecorator,在提交任务时把主线程的上下文复制到执行线程
- 方式2:阿里开源框架 TTL(TransmittableThreadLocal)
核心代码
/**
 * Static holder for the current request's {@link UserInfoContext}.
 *
 * <p>Two independent slots are offered:
 * <ul>
 *   <li>approach 1: a plain {@link ThreadLocal} (accessors without suffix). Its value is visible
 *       only to the thread that set it, so async work needs a task decorator to copy it;</li>
 *   <li>approach 2: a {@code TransmittableThreadLocal} (accessors with suffix {@code 2}). It is
 *       intended to follow tasks into async threads. NOTE(review): for pooled threads this
 *       presumably also needs TTL-wrapped executors/tasks or the TTL agent — confirm against the
 *       TTL documentation.</li>
 * </ul>
 *
 * <p>Whoever sets a value is responsible for calling {@link #remove()} / {@link #remove2()}
 * afterwards, otherwise the value stays attached to the (possibly pooled) thread.
 */
public class UserContextHolder {
    /** Approach 1: per-thread user context. */
    private static final ThreadLocal<UserInfoContext> userContext = new ThreadLocal<>();
    /** Approach 2: TTL-backed user context intended to propagate into async threads. */
    private static final TransmittableThreadLocal<UserInfoContext> userContext2 = new TransmittableThreadLocal<>();

    public UserContextHolder() {
    }

    /** Returns the calling thread's user context (approach 1), or {@code null} if none is set. */
    public static UserInfoContext getUserInfoContext() {
        return userContext.get();
    }

    /** Binds {@code userInfoContext} to the calling thread (approach 1). */
    public static void setUserInfoContext(UserInfoContext userInfoContext) {
        userContext.set(userInfoContext);
    }

    /** Clears the calling thread's user context (approach 1). */
    public static void remove() {
        userContext.remove();
    }

    /** Returns the TTL-held user context (approach 2), or {@code null} if none is set. */
    public static UserInfoContext getUserInfoContext2() {
        return userContext2.get();
    }

    /** Stores {@code userInfoContext} in the TTL slot (approach 2). */
    public static void setUserInfoContext2(UserInfoContext userInfoContext) {
        userContext2.set(userInfoContext);
    }

    /** Clears the TTL-held user context (approach 2). */
    public static void remove2() {
        userContext2.remove();
    }
}
/**
 * User information attached to the current request.
 *
 * <p>Lombok {@code @AllArgsConstructor} generates a {@code (token, userId, userName)} constructor
 * in field declaration order. {@code @Data} generates getters, setters, {@code equals}/{@code hashCode}
 * and {@code toString}, so instances are mutable.
 */
@AllArgsConstructor
@Data
public class UserInfoContext {
// Token of the caller; the interceptor fills it from the "token" request header.
private String token;
// Id of the user the token belongs to.
private String userId;
// Display name of the user the token belongs to.
private String userName;
}
/**
 * Spring MVC interceptor that binds the caller's {@link UserInfoContext} to both holder slots
 * before the handler runs and clears both once the request has completed.
 */
public class UserContextInterceptor implements HandlerInterceptor, Ordered {

    /**
     * Builds the user context for this request and stores it in both holder slots.
     *
     * @return always {@code true}, so request processing continues
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Demo only: the user id/name are hard-coded. Real code would resolve them from the token,
        // e.g. via a database lookup.
        UserInfoContext currentUser = new UserInfoContext(request.getHeader("token"), "userId-123", "userName-lixi");
        // Approach 1: plain ThreadLocal.
        UserContextHolder.setUserInfoContext(currentUser);
        // Approach 2: TransmittableThreadLocal.
        UserContextHolder.setUserInfoContext2(currentUser);
        return true;
    }

    /** Clears both slots so the request's user does not stay on the servlet (pooled) thread. */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        UserContextHolder.remove();
        UserContextHolder.remove2();
    }

    /** Interceptor order; 0 is this interceptor's fixed position. */
    @Override
    public int getOrder() {
        return 0;
    }
}
/**
 * Demo driver showing whether a user context stored in {@link UserContextHolder} reaches other
 * threads. It covers:
 * <ul>
 *   <li>a synchronous call;</li>
 *   <li>a raw {@code new Thread};</li>
 *   <li>a plain {@link ThreadPoolExecutor};</li>
 *   <li>a Spring {@code ThreadPoolTaskExecutor} decorated with {@code BusinessContextDecorator}
 *       (approach 1);</li>
 *   <li>the TTL-backed slot (approach 2).</li>
 * </ul>
 */
public class TestUserContextUtil {
    /** Plain pool: performs no context propagation of its own. */
    private static final ThreadPoolExecutor threadPoolExecutor;
    /** Approach 1: Spring pool whose tasks are wrapped by {@code BusinessContextDecorator}. */
    private static final ThreadPoolTaskExecutor threadPoolTaskExecutor;

    static {
        threadPoolExecutor = new ThreadPoolExecutor(10,
                20,
                10,
                TimeUnit.SECONDS,
                new LinkedBlockingDeque<>(100),
                new ThreadFactoryBuilder().setNameFormat("异步线程池-%d").build(),
                new ThreadPoolExecutor.AbortPolicy()
        );
        // Approach 1: same sizing as above, plus a TaskDecorator that copies the submitter's context.
        threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
        threadPoolTaskExecutor.setCorePoolSize(10);
        threadPoolTaskExecutor.setMaxPoolSize(20);
        threadPoolTaskExecutor.setKeepAliveSeconds(10);
        threadPoolTaskExecutor.setQueueCapacity(100);
        threadPoolTaskExecutor.setThreadFactory(new ThreadFactoryBuilder().setNameFormat("异步线程池-%d").build());
        threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
        threadPoolTaskExecutor.setTaskDecorator(new BusinessContextDecorator());
        threadPoolTaskExecutor.initialize();
    }

    /**
     * Sets a user context on the main thread, which simulates the interceptor. It then reads that
     * context through each propagation variant and finally shuts the pools down.
     */
    public static void main(String[] args) {
        UserInfoContext userInfoContext = new UserInfoContext("token-123", "userId-123", "userName-lixi");
        UserContextHolder.setUserInfoContext(userInfoContext);
        // Approach 2 reads the TTL slot, so populate it too (the interceptor sets both as well).
        UserContextHolder.setUserInfoContext2(userInfoContext);
        try {
            getSubThreadContext();
            getAsyncThreadContext();
            getAsyncThreadPool();
            getAsyncThreadPool2();
            getAsyncThreadPool3();
        } finally {
            UserContextHolder.remove();
            UserContextHolder.remove2();
            // Worker threads are non-daemon; without a shutdown the JVM never exits after main returns.
            threadPoolExecutor.shutdown();
            threadPoolTaskExecutor.shutdown();
        }
    }

    /** Synchronous call: runs on the same thread, so the context is visible. */
    public static void getSubThreadContext() {
        System.out.println("同步方法调用获取上下文:" + UserContextHolder.getUserInfoContext());
    }

    /** Async call on a brand-new thread; a plain ThreadLocal value set on main is not visible there. */
    public static void getAsyncThreadContext() {
        new Thread(() -> {
            try {
                // Short delay before reading. The original used MICROSECONDS (0.5 ms), almost certainly meant ms.
                TimeUnit.MILLISECONDS.sleep(500);
            } catch (InterruptedException e) {
                // Keep the interrupt status instead of swallowing it.
                Thread.currentThread().interrupt();
                e.printStackTrace();
            }
            System.out.println("异步线程 方法调用获取上下文:" + UserContextHolder.getUserInfoContext());
        }).start();
    }

    /** Async call on the plain pool: no decorator, so the plain ThreadLocal is not propagated. */
    public static void getAsyncThreadPool() {
        threadPoolExecutor.execute(() -> {
            System.out.println("异步线程 方法调用获取上下文:" + UserContextHolder.getUserInfoContext());
        });
    }

    /**
     * Approach 1: async call on the decorated pool.
     *
     * @return the context seen by the worker thread, or {@code null} on failure/timeout/interrupt
     */
    public static UserInfoContext getAsyncThreadPool2() {
        Future<UserInfoContext> future = threadPoolTaskExecutor.submit(() -> {
            System.out.println("异步线程-装饰器 方法调用获取上下文:" + UserContextHolder.getUserInfoContext());
            return UserContextHolder.getUserInfoContext();
        });
        return awaitContext(future);
    }

    /**
     * Approach 2: reads the TTL slot from a worker thread.
     *
     * <p>NOTE(review): per the TTL docs, a TransmittableThreadLocal follows tasks into pooled threads
     * only when the executor or task is TTL-wrapped (TtlExecutors / TtlRunnable / TtlCallable) or
     * the TTL Java agent is active. This pool is not wrapped, so a value seen here may only come
     * from InheritableThreadLocal inheritance when the worker thread was created. Confirm before
     * relying on it.
     *
     * @return the TTL-held context seen by the worker thread, or {@code null} on failure/timeout/interrupt
     */
    public static UserInfoContext getAsyncThreadPool3() {
        Future<UserInfoContext> future = threadPoolTaskExecutor.submit(() -> {
            UserInfoContext ttlContext = UserContextHolder.getUserInfoContext2();
            System.out.println("异步线程-ttl 方法调用获取上下文:" + ttlContext);
            // Return the TTL value that was printed (previously the plain ThreadLocal value was returned).
            return ttlContext;
        });
        return awaitContext(future);
    }

    /** Waits up to 10 seconds for an async lookup; returns {@code null} on failure, timeout or interrupt. */
    private static UserInfoContext awaitContext(Future<UserInfoContext> future) {
        try {
            return future.get(10, TimeUnit.SECONDS);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag for our caller.
                Thread.currentThread().interrupt();
            }
            // Don't leave a timed-out task running in the pool.
            future.cancel(true);
            e.printStackTrace();
            return null;
        }
    }
}
/**
 * Spring {@link TaskDecorator} that carries the submitting thread's {@link UserInfoContext}
 * into the thread that executes the task. The worker thread's previous value is restored
 * afterwards, so the request's user is not left behind on a pooled thread.
 */
public class BusinessContextDecorator implements TaskDecorator {
    @Override
    public Runnable decorate(Runnable runnable) {
        // Captured when decorate() is called, i.e. on the submitting thread.
        UserInfoContext userContext = UserContextHolder.getUserInfoContext();
        // To also propagate an MDC traceId, capture MDC.getCopyOfContextMap() here and
        // set/restore it around runnable.run() below in the same way.
        return () -> {
            // Remember what the executing thread held. This matters if the task ever runs on the
            // caller thread (e.g. with CallerRunsPolicy): that thread's context must survive.
            UserInfoContext previous = UserContextHolder.getUserInfoContext();
            UserContextHolder.setUserInfoContext(userContext);
            try {
                runnable.run();
            } finally {
                // Never leave another request's user on a pooled worker thread.
                if (previous == null) {
                    UserContextHolder.remove();
                } else {
                    UserContextHolder.setUserInfoContext(previous);
                }
            }
        };
    }
}
测试代码
/**
 * Demo endpoint: returns the user context obtained by {@code userContextService.getAsyncThreadPool()}
 * as a string.
 *
 * <p>Returns {@code "null"} rather than failing with an NPE (HTTP 500) when no context comes back.
 * Assumption: the service may return null on failure or timeout, as the async helpers in this file
 * do. Verify against the service.
 */
@GetMapping("/test/userContext")
public String testUserContext() {
    UserInfoContext userInfoContext = userContextService.getAsyncThreadPool();
    return String.valueOf(userInfoContext);
}
/**
 * Demo endpoint for the TTL approach: returns the user context obtained by
 * {@code userContextService.getAsyncThreadPool3()} as a string.
 *
 * <p>Returns {@code "null"} rather than failing with an NPE (HTTP 500) when no context comes back.
 * Assumption: the service may return null on failure or timeout, as the async helpers in this file
 * do. Verify against the service.
 */
@GetMapping("/test/userContext/ttl")
public String testUserContext2() {
    UserInfoContext userInfoContext = userContextService.getAsyncThreadPool3();
    return String.valueOf(userInfoContext);
}
验证效果
异步线程-装饰器 方法调用获取上下文:UserInfoContext(token=111, userId=userId-123, userName=userName-lixi)
异步线程-ttl 方法调用获取上下文:UserInfoContext(token=111, userId=userId-123, userName=userName-lixi)