背景
在项目中使用多线程抓取第三方数据执行数据入库时,如果某个子线程执行异常,其他子线事务全部回滚,spring对多线程无法进行事务控制,是因为多线程底层连接数据库的时候,是使用的线程变量(TheadLocal),线程之间事务隔离,每个线程有自己的连接,事务肯定不是同一个了。
解决办法
思想就是使用两个CountDownLatch实现子线程的二段提交
步骤:
1、主线程将任务分发给子线程,然后使用childMonitor.await();阻塞主线程,等待所有子线程处理向数据库中插入的业务,并使用BlockingDeque存储线程的返回结果。
2、使用childMonitor.countDown()释放子线程锁定,同时使用mainMonitor.await();阻塞子线程,将程序的控制权交还给主线程。
3、主线程检查子线程执行任务的结果,若有失败结果出现,主线程标记状态告知子线程回滚,然后使用mainMonitor.countDown();将程序控制权再次交给子线程,子线程检测回滚标志,判断是否回滚。
代码实现
线程池工具类
publicclassThreadPoolTool {
/** * 多线程任务
* @param transactionManager
* @param data
* @param threadCount
* @param params
* @param clazz
*/publicvoidexcuteTask(DataSourceTransactionManager transactionManager, List data,intthreadCount, Map params, Class clazz) {
if(data ==null|| data.size() == 0) {
return;
}
intbatch = 0;
ExecutorService executor = Executors.newFixedThreadPool(threadCount);
//监控子线程的任务执行CountDownLatch childMonitor =new CountDownLatch(threadCount);
//监控主线程,是否需要回滚CountDownLatch mainMonitor =newCountDownLatch(1);
//存储任务的返回结果,返回true表示不需要回滚,反之,则回滚BlockingDeque results =newLinkedBlockingDeque(threadCount);
RollBack rollback =newRollBack(false);
try {
LinkedBlockingQueue queue = splitQueue(data, threadCount);
while(true) {
List list = queue.poll();
if(list ==null) {
break;
}
batch++;
params.put("batch", batch);
Constructor constructor = clazz.getConstructor(newClass[]{CountDownLatch.class, CountDownLatch.class, BlockingDeque.class, RollBack.class, DataSourceTransactionManager.class, Object.class, Map.class});
ThreadTask task = (ThreadTask) constructor.newInstance(childMonitor, mainMonitor, results, rollback, transactionManager, list, params);
executor.execute(task);
}
// 1、主线程将任务分发给子线程,然后使用childMonitor.await();阻塞主线程,等待所有子线程处理向数据库中插入的业务。 childMonitor.await();
System.out.println("主线程开始执行任务");
//根据返回结果来确定是否回滚for(inti = 0; i < threadCount; i++) {
Boolean result = results.take();
if(!result) {
//有线程执行异常,需要回滚子线程rollback.setNeedRoolBack(true);
}
}
// 3、主线程检查子线程执行任务的结果,若有失败结果出现,主线程标记状态告知子线程回滚,然后使用mainMonitor.countDown();将程序控制权再次交给子线程,子线程检测回滚标志,判断是否回滚。 mainMonitor.countDown();
} catch (Exception e) {
log.error(e.getMessage());
} finally {
//关闭线程池,释放资源 executor.shutdown();
}
}
/** * 队列拆分
*
* @param data 需要执行的数据集合
* @param threadCount 核心线程数
* @return*/privateLinkedBlockingQueue> splitQueue(List data,int threadCount) {
LinkedBlockingQueue> queueBatch =new LinkedBlockingQueue();
inttotal = data.size();
intoneSize = total / threadCount;
int start;
int end;
for(inti = 0; i < threadCount; i++) {
start = i * oneSize;
end = (i + 1) * oneSize;
if(i < threadCount - 1) {
queueBatch.add(data.subList(start, end));
} else {
queueBatch.add(data.subList(start, data.size()));
}
}
return queueBatch;
}
}
子线程任务执行类
publicabstractclassThreadTaskimplements Runnable {
/** * 监控子任务的执行
*/private CountDownLatch childMonitor;
/** * 监控主线程
*/private CountDownLatch mainMonitor;
/** * 存储线程的返回结果
*/privateBlockingDeque resultList;
/** * 回滚类
*/private RollBack rollback;
privateMap params;
protected Object obj;
protected DataSourceTransactionManager transactionManager;
protected TransactionStatus status;
publicThreadTask(CountDownLatch childCountDown, CountDownLatch mainCountDown, BlockingDeque result, RollBack rollback, DataSourceTransactionManager transactionManager, Object obj,Map params) {
this.childMonitor = childCountDown;
this.mainMonitor = mainCountDown;
this.resultList = result;
this.rollback = rollback;
this.transactionManager = transactionManager;
this.obj = obj;
this.params = params;
initParam();
}
/** * 事务回滚
*/privatevoid rollBack() {
System.out.println(Thread.currentThread().getName()+"开始回滚");
transactionManager.rollback(status);
}
/** * 事务提交
*/privatevoid submit() {
System.out.println(Thread.currentThread().getName()+"提交事务");
transactionManager.commit(status);
}
protected Object getParam(String key){
return params.get(key);
}
publicabstractvoid initParam();
/** * 执行任务,返回false表示任务执行错误,需要回滚
* @return*/publicabstractboolean processTask();
@Override
publicvoid run() {
System.out.println(Thread.currentThread().getName()+"子线程开始执行任务");
DefaultTransactionDefinition def =new DefaultTransactionDefinition();
def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
status = transactionManager.getTransaction(def);
Boolean result = processTask();
//向队列中添加处理结果 resultList.add(result);
//2、使用childMonitor.countDown()释放子线程锁定,同时使用mainMonitor.await();阻塞子线程,将程序的控制权交还给主线程。 childMonitor.countDown();
try {
//等待主线程的判断逻辑执行完,执行下面的是否回滚逻辑 mainMonitor.await();
} catch (Exception e) {
log.error(e.getMessage());
}
System.out.println(Thread.currentThread().getName()+"子线程执行剩下的任务");
//3、主线程检查子线程执行任务的结果,若有失败结果出现,主线程标记状态告知子线程回滚,然后使用mainMonitor.countDown();将程序控制权再次交给子线程,子线程检测回滚标志,判断是否回滚。if (rollback.isNeedRoolBack()) {
rollBack();
}else{
//事务提交 submit();
}
}
回滚标记类
@Datapublicclass RollBack {
publicRollBack(boolean needRoolBack) {
this.needRoolBack = needRoolBack;
}
privateboolean needRoolBack;
}
使用线程池工具:
1,首先建立自己的任务执行类 并且 extends ThreadTask ,实现initParam()和processTask()方法
/** * 多线程处理任务类
*/publicclassTestTaskextends ThreadTask {
/** 分批处理的数据
*/privateList objectList;
/** * 可能需要注入的某些服务
*/private TestService testService;
publicTestTask(CountDownLatch childCountDown, CountDownLatch mainCountDown, BlockingDeque result, RollBack rollback, DataSourceTransactionManager transactionManager, Object obj, Map params) {
super(childCountDown, mainCountDown, result, rollback, transactionManager, obj, params);
}
@Override
publicvoid initParam() {
this.objectList = (List) getParam("objectList");
this.testService = (TestService) getParam("testService");
}
/** * 执行任务,返回false表示任务执行错误,需要回滚
* @return*/ @Override
publicboolean processTask() {
try {
for (Object o : objectList) {
testService.list();
System.out.println(o.toString()+"执行自己的多线程任务逻辑");
}
returntrue;
} catch (Exception e) {
returnfalse;
}
}
}
2,编写主任务执行方法
/** * 执行多线程任务方法
*/publicvoid testThreadTask() {
try {
intthreadCount = 5;
//需要分批处理的数据List objectList =newArrayList<>();
Map params =newHashMap<>();
params.put("objectList",objectList);
params.put("testService",testService);
//调用多线程工具方法threadPoolTool.excuteTask(transactionManager,objectList,threadCount,params, TestTask.class);
}catch (Exception e){
thrownew RuntimeException(e.getMessage());
}
}