美文网首页
Java 多线程事务回滚 ——多线程插入数据库时事务控制

Java 多线程事务回滚 ——多线程插入数据库时事务控制

作者: 楼兰King | 来源:发表于2021-03-12 16:47 被阅读0次

    背景
    日常项目中,经常会出现一个场景,同时批量插入数据库数据,由于逻辑复杂或者其它原因,我们无法使用sql进行批量插入。串行效率低,耗时长,为了提高效率,这个时候我们首先想到多线程并发插入,但是如何控制事务呢 … 直接上干货

    实现效果
    开启多条子线程,并发插入数据库 当其中一条线程出现异常,或者处理结果为非预期结果,则全部线程均回滚
    代码实现

    @Service
    public class CompanyUserBatchServiceImpl implements CompanyUserBatchService {
        private static final Logger logger = LoggerFactory.getLogger(CompanyUserBatchServiceImpl.class);
    
        @Autowired
        private CompanyUserService companyUserService;
    
        @Override
        public ReturnData addNewCurrentCompanyUsers(String params) {
            logger.info("addNewCompanyUsers 新增参保人方法");
            logger.info(">>>>>>>>>>>>参数:{}", params);
            ReturnData rd = new ReturnData();
            rd.setRetCode(CommonConstants.RETURN_CODE_FAIL);
            if (StringUtils.isBlank(params)) {
                rd.setMsg("入参为空!");
                logger.info(">>>>>>入参为空。");
                return rd;
            }
    
            List<CompanyUserResultVo> companyUsers;
            try {
                companyUsers = JSONObject.parseArray(params, CompanyUserResultVo.class);
            } catch (Exception e) {
                logger.info(">>>>>>>>>入参格式有误: {}", e);
                rd.setMsg("入参格式有误!");
                return rd;
            }
    
    
            //每条线程最小处理任务数
            int perThreadHandleCount = 1;
            //线程池的最大线程数
            int nThreads = 10;
            int taskSize = companyUsers.size();
    
            if (taskSize > nThreads * perThreadHandleCount) {
                perThreadHandleCount = taskSize % nThreads == 0 ? taskSize / nThreads : taskSize / nThreads + 1;
                nThreads = taskSize % perThreadHandleCount == 0 ? taskSize / perThreadHandleCount : taskSize / perThreadHandleCount + 1;
            } else {
                nThreads = taskSize;
            }
    
            logger.info("批量添加参保人taskSize: {}, perThreadHandleCount: {}, nThreads: {}", taskSize, perThreadHandleCount, nThreads);
            CountDownLatch mainLatch = new CountDownLatch(1);
            //监控子线程
            CountDownLatch threadLatch = new CountDownLatch(nThreads);
            //根据子线程执行结果判断是否需要回滚
            BlockingDeque<Boolean> resultList = new LinkedBlockingDeque<>(nThreads);
            //必须要使用对象,如果使用变量会造成线程之间不可共享变量值
            RollBack rollBack = new RollBack(false);
            ExecutorService fixedThreadPool = Executors.newFixedThreadPool(nThreads);
    
            List<Future<List<Object>>> futures = Lists.newArrayList();
            List<Object> returnDataList = Lists.newArrayList();
            //给每个线程分配任务
            for (int i = 0; i < nThreads; i++) {
                int lastIndex = (i + 1) * perThreadHandleCount;
                List<CompanyUserResultVo> companyUserResultVos = companyUsers.subList(i * perThreadHandleCount, lastIndex >= taskSize ? taskSize : lastIndex);
                AddNewCompanyUserThread addNewCompanyUserThread = new AddNewCompanyUserThread(mainLatch, threadLatch, rollBack, resultList, companyUserResultVos);
                Future<List<Object>> future = fixedThreadPool.submit(addNewCompanyUserThread);
                futures.add(future);
            }
    
            /** 存放子线程返回结果. */
            List<Boolean> backUpResult = Lists.newArrayList();
            try {
                //等待所有子线程执行完毕
                boolean await = threadLatch.await(20, TimeUnit.SECONDS);
                //如果超时,直接回滚
                if (!await) {
                    rollBack.setRollBack(true);
                } else {
                    logger.info("创建参保人子线程执行完毕,共 {} 个线程", nThreads);
                    //查看执行情况,如果有存在需要回滚的线程,则全部回滚
                    for (int i = 0; i < nThreads; i++) {
                        Boolean result = resultList.take();
                        backUpResult.add(result);
                        logger.debug("子线程返回结果result: {}", result);
                        if (result) {
                            /** 有线程执行异常,需要回滚子线程. */
                            rollBack.setRollBack(true);
                        }
                    }
                }
            } catch (InterruptedException e) {
                logger.error("等待所有子线程执行完毕时,出现异常");
                throw new SystemException("等待所有子线程执行完毕时,出现异常,整体回滚");
            } finally {
                //子线程再次开始执行
                mainLatch.countDown();
                logger.info("关闭线程池,释放资源");
                fixedThreadPool.shutdown();
            }
    
            /** 检查子线程是否有异常,有异常整体回滚. */
            for (int i = 0; i < nThreads; i++) {
                if (CollectionUtils.isNotEmpty(backUpResult)) {
                    Boolean result = backUpResult.get(i);
                    if (result) {
                        logger.info("创建参保人失败,整体回滚");
                        throw new SystemException("创建参保人失败");
                    }
                } else {
                    logger.info("创建参保人失败,整体回滚");
                    throw new SystemException("创建参保人失败");
                }
            }
    
            //拼接结果
            try {
                for (Future<List<Object>> future : futures) {
                    returnDataList.addAll(future.get());
                }
            } catch (Exception e) {
                logger.info("获取子线程操作结果出现异常,创建的参保人列表: {} ,异常信息: {}", JSONObject.toJSONString(companyUsers), e);
                throw new SystemException("创建参保人子线程正常创建参保人成功,主线程出现异常,回滚失败");
            }
    
            rd.setRetCode(CommonConstants.RETURN_CODE_SUCCESS);
            rd.setData(returnDataList);
            return rd;
        }
    
        public class AddNewCompanyUserThread implements Callable<List<Object>> {
            /**
             * 主线程监控
             */
            private CountDownLatch mainLatch;
            /**
             * 子线程监控
             */
            private CountDownLatch threadLatch;
            /**
             * 是否回滚
             */
            private RollBack rollBack;
            private BlockingDeque<Boolean> resultList;
            private List<CompanyUserResultVo> taskList;
    
            public AddNewCompanyUserThread(CountDownLatch mainLatch, CountDownLatch threadLatch, RollBack rollBack, BlockingDeque<Boolean> resultList, List<CompanyUserResultVo> taskList) {
                this.mainLatch = mainLatch;
                this.threadLatch = threadLatch;
                this.rollBack = rollBack;
                this.resultList = resultList;
                this.taskList = taskList;
            }
    
            @Override
            public List<Object> call() {
                //为了保证事务不提交,此处只能调用一个有事务的方法,spring 中事务的颗粒度是方法,只有方法不退出,事务才不会提交
                return companyUserService.addNewCompanyUsers(mainLatch, threadLatch, rollBack, resultList, taskList);
            }
    
        }
    
        public class RollBack {
            private Boolean isRollBack;
    
            public Boolean getRollBack() {
                return isRollBack;
            }
    
            public void setRollBack(Boolean rollBack) {
                isRollBack = rollBack;
            }
    
            public RollBack(Boolean isRollBack) {
                this.isRollBack = isRollBack;
            }
        }
    
    public List<Object> addNewCompanyUsers(CountDownLatch mainLatch, CountDownLatch threadLatch, CompanyUserBatchServiceImpl.RollBack rollBack, BlockingDeque<Boolean> resultList, List<CompanyUserResultVo> taskList) {
            List<Object> returnDataList = Lists.newArrayList();
            Boolean result = false;
            logger.debug("线程: {}创建参保人条数 : {}", Thread.currentThread().getName(), taskList.size());
            try {
                for (CompanyUserResultVo companyUserResultVo : taskList) {
                    ReturnData returnData = addSingleCompanyUser(companyUserResultVo);
                    if (returnData.getRetCode() == CommonConstants.RETURN_CODE_FAIL) {
                        result = true;
                    }
                    returnDataList.add(returnData.getData());
                }
                //Exception 和 Error 都需要抓
            } catch (Throwable throwable) {
                throwable.printStackTrace();
                logger.info("线程: {}创建参保人出现异常: {} ", Thread.currentThread().getName(), throwable);
                result = true;
            }
    
            resultList.add(result);
            threadLatch.countDown();
            logger.info("子线程 {} 计算过程已经结束,等待主线程通知是否需要回滚", Thread.currentThread().getName());
    
            try {
                mainLatch.await();
                logger.info("子线程 {} 再次启动", Thread.currentThread().getName());
            } catch (InterruptedException e) {
                logger.error("批量创建参保人线程InterruptedException异常");
                throw new SystemException("批量创建参保人线程InterruptedException异常");
            }
    
            if (rollBack.getRollBack()) {
                logger.error("批量创建参保人线程回滚, 线程: {}, 需要更新的信息taskList: {}",
                        Thread.currentThread().getName(),
                        JSONObject.toJSONString(taskList));
                logger.info("子线程 {} 执行完毕,线程退出", Thread.currentThread().getName());
                throw new SystemException("批量创建参保人线程回滚");
            }
    
            logger.info("子线程 {} 执行完毕,线程退出", Thread.currentThread().getName());
            return returnDataList;
        }
    

    思想就是使用两个CountDownWatch实现子线程的二段提交
    步骤:

    主线程将任务分发给子线程,然后 使用 boolean await = threadLatch.await(20, TimeUnit.SECONDS); 阻塞主线程,等待所有子线程处理向数据库中插入的业务 使用 threadLatch.countDown(); 释放子线程锁定,同时使用 mainLatch.await(); 阻塞子线程,将程序的控制权交还给主线程 主线程检查子线程执行插入数据库的结果,若有非预期结果出现,主线程标记状态告知子线程回滚,然后使用 mainLatch.countDown(); 将程序控制权再次交给子线程,子线程检测回滚标志,判断是否回滚 子线程执行结束,主线程拼接处理结果,响应给请求方
    整个过程类似于GC的标记-清除过程(串行的垃圾收集器)

    相关文章

      网友评论

          本文标题:Java 多线程事务回滚 ——多线程插入数据库时事务控制

          本文链接:https://www.haomeiwen.com/subject/sipzqltx.html