Redian新闻
>
多线程如何实现事务回滚?一招帮你搞定!

多线程如何实现事务回滚?一招帮你搞定!

公众号新闻

点击上方“芋道源码”,选择“设为星标

管她前浪,还是后浪?

能浪的浪,才是好浪!

每天 10:33 更新文章,每天掉亿点点头发...

源码精品专栏

 
来源:blog.csdn.net/u010978399/
article/details/117771620

特别说明CountDownLatch

CountDownLatch是一个类springboot自带的类,可以直接用 ,变量AtomicBoolean 也是可以直接使用

基于 Spring Boot + MyBatis Plus + Vue & Element 实现的后台管理系统 + 用户小程序,支持 RBAC 动态权限、多租户、数据权限、工作流、三方登录、支付、短信、商城等功能

  • 项目地址:https://github.com/YunaiV/ruoyi-vue-pro
  • 视频教程:https://doc.iocoder.cn/video/

CountDownLatch的用法

CountDownLatch典型用法:

1、某一线程在开始运行前等待n个线程执行完毕。 将CountDownLatch的计数器初始化为new CountDownLatch(n),每当一个任务线程执行完毕,就将计数器减1 countdownLatch.countDown(),当计数器的值变为0时,在CountDownLatch上await()的线程就会被唤醒。一个典型应用场景就是启动一个服务时,主线程需要等待多个组件加载完毕,之后再继续执行。

2、实现多个线程开始执行任务的最大并行性。 注意是并行性,不是并发,强调的是多个线程在某一时刻同时开始执行。类似于赛跑,将多个线程放到起点,等待发令枪响,然后同时开跑。做法是初始化一个共享的CountDownLatch(1),将其计算器初始化为1,多个线程在开始执行任务前首先countdownlatch.await(),当主线程调用countDown()时,计数器变为0,多个线程同时被唤醒。

基于 Spring Cloud Alibaba + Gateway + Nacos + RocketMQ + Vue & Element 实现的后台管理系统 + 用户小程序,支持 RBAC 动态权限、多租户、数据权限、工作流、三方登录、支付、短信、商城等功能

  • 项目地址:https://github.com/YunaiV/yudao-cloud
  • 视频教程:https://doc.iocoder.cn/video/

CountDownLatch(num) 简单说明

new 一个 CountDownLatch(num) 对象

建立对象的时候 num 代表的是需要等待 num 个线程

// 建立对象的时候 num 代表的是需要等待 num 个线程
//主线程
CountDownLatch mainThreadLatch = new CountDownLatch(num);
//子线程
CountDownLatch rollBackLatch  = new CountDownLatch(1);

主线程:mainThreadLatch.await() 和mainThreadLatch.countDown()

新建对象

CountDownLatch mainThreadLatch = new CountDownLatch(num);

卡住主线程,让其等待子线程,代码mainThreadLatch.await(),放在主线程里

mainThreadLatch.await();

代码mainThreadLatch.countDown(),放在子线程里,每一个子线程运行一到这个代码,意味着CountDownLatch(num),里面的num-1(自动减一)

mainThreadLatch.countDown();

CountDownLatch(num)里面的num减到0,也就是CountDownLatch(0),被卡住的主线程mainThreadLatch.await(),就会往下执行

子线程:rollBackLatch.await() 和rollBackLatch.countDown()

新建对象,特别注意:子线程这个num就是1(关于只能为1的解答在后面)

CountDownLatch rollBackLatch  = new CountDownLatch(1);

卡住子线程,阻止每一个子线程的事务提交和回滚

rollBackLatch.await();

代码rollBackLatch.countDown();放在主线程里,而且是放在主线程的等待代码mainThreadLatch.await();后面。

rollBackLatch.countDown();

为什么所有的子线程会在一瞬间就被所有都释放了?

事务的回滚是怎么结合进去的?

假设总共20个子线程,那么其中一个线程报错了怎么实现所有线程回滚。

引入变量

AtomicBoolean rollbackFlag = new AtomicBoolean(false)

和字面意思是一样的:根据 rollbackFlag 的true或者false 判断子线程里面,是否回滚。

首先我们确定的一点:rollbackFlag 是所有的子线程都用着这一个判断

主线程类Entry

package org.apache.dolphinscheduler.api.utils;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.controller.WorkThread;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.springframework.web.bind.annotation.*;

import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;


@RestController
@RequestMapping("importDatabase")
public class Entry {

    /**
     * @param dbid 数据库的id
     * @param tablename 表名
     * @param sftpFileName 文件名称
     * @param head 是否有头文件
     * @param splitSign 分隔符
     * @param type 数据库类型
     */

    private static String SFTP_HOST = "192.168.1.92";
    private static int SFTP_PORT = 22;
    private static String SFTP_USERNAME = "root";
    private static String SFTP_PASSWORD = "rootroot";
    private static String SFTP_BASEPATH = "/opt/testSFTP/";
    @PostMapping("/thread")
    @ResponseBody
    public static JSONObject importDatabase(@RequestParam("dbid") int dbid
            ,@RequestParam("tablename") String tablename
            ,@RequestParam("sftpFileName") String sftpFileName
            ,@RequestParam("head") String head
            ,@RequestParam("splitSign") String splitSign
            ,@RequestParam("type") DbType type
            ,@RequestParam("heads") String heads
            ,@RequestParam("scolumns") String scolumns
            ,@RequestParam("tcolumns") String tcolumns ) throws Exception 
{
        JSONObject obForRetrun = new JSONObject();

        try {

            JSONArray jsonArray = JSONArray.parseArray(tcolumns);
            JSONArray scolumnArray = JSONArray.parseArray(scolumns);
            JSONArray headsArray = JSONArray.parseArray(heads);
            List<Integer> listInteger = getRrightDataNum(headsArray,scolumnArray);
            JSONArray bodys = SFTPUtils.getSftpContent(SFTP_HOST,SFTP_PORT,SFTP_USERNAME,SFTP_PASSWORD,SFTP_BASEPATH,sftpFileName,head,splitSign);
            int total  = bodys.size();
            int num = 20//一个批次的数据有多少
            int count = total/num;//周期
            int lastNum =total- count*num;//余数

            List<Thread> list = new ArrayList<Thread>();
            SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
            TimeZone t = sdf.getTimeZone();
            t.setRawOffset(0);
            sdf.setTimeZone(t);
            Long startTime=System.currentTimeMillis();


            int countForCountDownLatch = 0;
            if(lastNum==0){//整除
                countForCountDownLatch= count;
            }else{
                countForCountDownLatch= count + 1;
            }
            //子线程
            CountDownLatch rollBackLatch  = new CountDownLatch(1);
            //主线程
            CountDownLatch mainThreadLatch = new CountDownLatch(countForCountDownLatch);

            AtomicBoolean rollbackFlag = new AtomicBoolean(false);
            StringBuffer message = new StringBuffer();
            message.append("报错信息:");

            //子线程
            for(int i=0;i<count;i++) {//这里的count代表有几个线程
                Thread g = new Thread(new WorkThread(i,num,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message ));
                g.start();
                list.add(g);
            }

            if(lastNum!=0){//有小数的情况下
                Thread g = new Thread(new WorkThread(0,lastNum,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message ));
                g.start();
                list.add(g);
            }

//            for(Thread thread:list){
//                System.out.println(thread.getState());
//                thread.join();//是等待这个线程结束;
//            }

            mainThreadLatch.await();
            //所有等待的子线程全部放开
            rollBackLatch.countDown();

            //是主线程等待子线程的终止。也就是说主线程的代码块中,如果碰到了t.join()方法,此时主线程需要等待(阻塞),等待子线程结束了(Waits for this thread to die.),才能继续执行t.join()之后的代码块。


            Long endTime=System.currentTimeMillis();
            System.out.println("总共用时: "+sdf.format(new Date(endTime-startTime)));

            if(rollbackFlag.get()){
                obForRetrun.put("code",500);
                obForRetrun.put("msg",message);
            }else{
                obForRetrun.put("code",200);
                obForRetrun.put("msg","提交成功!");
            }
            obForRetrun.put("data",null);
        }catch (InterruptedException e){
            e.printStackTrace();
            obForRetrun.put("code",500);
            obForRetrun.put("msg",e.getMessage());
            obForRetrun.put("data",null);
        }
        return obForRetrun;

    }

    /**
     * 文件里第几列被作为导出列
     * @param headsArray
     * @param scolumnArray
     * @return
     */

    public static List<Integer> getRrightDataNum(JSONArray headsArray, JSONArray scolumnArray){
        List<Integer> list = new ArrayList<Integer>();
        String arrayA [] = new String[headsArray.size()];
        for(int i=0;i<headsArray.size();i++){
            JSONObject ob  = (JSONObject)headsArray.get(i);
            arrayA[i] =String.valueOf(ob.get("title"));
        }

        String arrayB [] = new String[scolumnArray.size()];
        for(int i=0;i<scolumnArray.size();i++){
            JSONObject ob  = (JSONObject)scolumnArray.get(i);
            arrayB[i] =String.valueOf(ob.get("columnName"));
        }

        for(int i =0;i<arrayA.length;i++){
            for(int j=0;j<arrayB.length;j++){
                if(arrayA[i].equals(arrayB[j])){
                    list.add(i);
                    break;
                }
            }
        }

        return list;
    }
}

子线程类WorkThread

package org.apache.dolphinscheduler.api.controller;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.service.DataSourceService;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.apache.dolphinscheduler.dao.entity.DataSource;
import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.springframework.transaction.PlatformTransactionManager;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;


/**
 * 多线程
 */

public class WorkThread implements Runnable//建立线程的两种方法 1 实现Runnable 接口 2 继承 Thread 类

    private DataSourceService dataSourceService;

    private DataSourceMapper dataSourceMapper;

    private Integer begin;
    private Integer end;
    private String tableName;
    private JSONArray columnArray;
    private Integer dbid;
    private DbType type;
    private JSONArray bodys;
    private  List<Integer> listInteger;
    private PlatformTransactionManager transactionManager;
    private CountDownLatch mainThreadLatch;
    private CountDownLatch rollBackLatch;
    private AtomicBoolean rollbackFlag;
    private StringBuffer message;



    /**
     * @param i
     * @param num
     * @param tableFrom
     * @param columnArrayFrom
     * @param dbidFrom
     * @param typeFrom
     */

    public WorkThread(int i, int num, String tableFrom, JSONArray columnArrayFrom, int dbidFrom
            , DbType typeFrom, JSONArray bodysFrom, List<Integer> listIntegerFrom
            ,CountDownLatch mainThreadLatch,CountDownLatch rollBackLatch,AtomicBoolean rollbackFlag
            ,StringBuffer messageFrom)
 
{
        begin=i*num;
        end=begin+num;
        tableName = tableFrom;
        columnArray = columnArrayFrom;
        dbid = dbidFrom;
        type = typeFrom;
        bodys = bodysFrom;
        listInteger = listIntegerFrom;
        this.dataSourceMapper = SpringApplicationContext.getBean(DataSourceMapper.class);
        this.dataSourceService = SpringApplicationContext.getBean(DataSourceService.class);
        this.transactionManager = SpringApplicationContext.getBean(PlatformTransactionManager.class);
        this.mainThreadLatch = mainThreadLatch;
        this.rollBackLatch = rollBackLatch;
        this.rollbackFlag = rollbackFlag;
        this.message = messageFrom;
    }

    public void run() {

        DataSource dataSource = dataSourceMapper.queryDataSourceByID(dbid);
        String cp = dataSource.getConnectionParams();
        Connection con=null;
            con =  dataSourceService.getConnection(type,cp);
        if(con!=null)
        {
            SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
            TimeZone t = sdf.getTimeZone();
            t.setRawOffset(0);
            sdf.setTimeZone(t);
            Long startTime = System.currentTimeMillis();
            try {
                con.setAutoCommit(false);

//---------------------------- 获取字段和类型
                String columnString = null;//活动的字段
                int intForType = 0;
                String type[] = new String[columnArray.size()];//类型集合
                for(int i=0;i<columnArray.size();i++){
                    JSONObject ob = (JSONObject)columnArray.get(i);
                    if(columnString==null){
                        columnString = String.valueOf(ob.get("name"));
                    }else{
                        columnString = columnString + ","+ String.valueOf(ob.get("name"));
                    }
                    type[intForType] = String.valueOf(ob.get("type"));
                    intForType = intForType + 1;
                }
                intForType = 0;

                //这一步是为了形成 insert into "+tableName+"(id,name,age) values (?,?,?);
                String dataString  = null;
                for(int i=0;i<columnArray.size();i++){
                    if(dataString==null){
                        dataString = "?";
                    }else{
                        dataString = dataString +","+"?";
                    }
                }

//--------------------------------

                StringBuffer sql = new StringBuffer();
                sql = sql.append("insert into "+tableName+"("+columnString+") values ("+dataString+")") ;
                PreparedStatement pst= (PreparedStatement)con.prepareStatement(sql.toString());
                for(int i=begin;i<end;i++) {
                    JSONObject ob = (JSONObject)bodys.get(i);
                    if(ob!=null){
                        String [] array = ob.get(i).toString().split("\\,");
                        String [] arrayFinal = getFinalData(listInteger,array);
                        for(int j=0;j<type.length;j++){
                            String typeString  = type[j].toLowerCase();
                            int z = j+1;
                            if("string".equals(typeString)||"varchar".equals(typeString)){
                                pst.setString(z,arrayFinal[j]);//这里的第一个参数 是指 替换第几个?
                            }else if("int".equals(typeString)||"bigint".equals(typeString)){
                                pst.setInt(z,Integer.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个?
                            }else if("long".equals(typeString)){
                                pst.setLong(z,Long.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个?
                            }else if("double".equals(typeString)){
                                pst.setDouble(z,Double.parseDouble(arrayFinal[j]));
                            }else if("date".equals(typeString)||"datetime".equals(typeString)){
                                pst.setDate(z, setDateback(arrayFinal[j]));
                            }else if("Timestamp".equals(typeString)){
                                pst.setTimestamp(z, setTimestampback(arrayFinal[j]));
                            }
                        }
                    }
                    pst.addBatch();
                }
                pst.executeBatch();

                mainThreadLatch.countDown();
                rollBackLatch.await();

                if(rollbackFlag.get()){
                    con.rollback();//表示回滚事务;
                }else{
                    con.commit();//事务提交
                }
                con.close();
            } catch (Exception e) {
                System.out.println(e.getMessage());
                message = message.append(e.getMessage());
                rollbackFlag.set(true);
                mainThreadLatch.countDown();
                try {
                    con.close();
                } catch (SQLException throwables) {
                    throwables.printStackTrace();
                }
            }
            Long endTime = System.currentTimeMillis();
            System.out.println(Thread.currentThread().getName()+":startTime= "+sdf.format(new Date(startTime))+",endTime= "+sdf.format(new Date(endTime))
                    +" 用时:"+sdf.format(new Date(endTime - startTime)));

        }
    }


    public java.sql.Date setDateback(String dateString) throws ParseException {
        SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" );
        java.util.Date date = sdf.parse( "2015-5-6 10:30:00" );
        long lg = date.getTime();// 日期 转 时间戳
        return new java.sql.Date( lg );
    }

    public java.sql.Timestamp setTimestampback(String dateString) throws ParseException {
        SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" );
        java.util.Date date = sdf.parse( "2015-5-6 10:30:00" );
        long lg = date.getTime();// 日期 转 时间戳
        return new java.sql.Timestamp( lg );
    }

    public String [] getFinalData(List<Integer> listInteger,String[] array){
        String [] arrayFinal = new String [listInteger.size()];
        for(int i=0;i<listInteger.size();i++){
            int a = listInteger.get(i);
            arrayFinal[i] = array[a];
        }
        return arrayFinal;
    }
}

代码实际运用踩坑!!!!

还记得这里有个一批次处理多少数据么,我这边设置了20,实际到运用中的时候客户给了个20W的数据,我批次设置为20,那就有1W个子线程!!!!

这还不是最糟糕的,最糟糕的是每个子线程都会创建一个数据库连接,数据库直接被我搞炸了

所以这里需要把:

int num = 20//一个批次的数据有多少

改成:

int num = 20000//一个批次的数据有多少


欢迎加入我的知识星球,一起探讨架构,交流源码。加入方式,长按下方二维码噢

已在知识星球更新源码解析如下:

最近更新《芋道 SpringBoot 2.X 入门》系列,已经 101 余篇,覆盖了 MyBatis、Redis、MongoDB、ES、分库分表、读写分离、SpringMVC、Webflux、权限、WebSocket、Dubbo、RabbitMQ、RocketMQ、Kafka、性能测试等等内容。

提供近 3W 行代码的 SpringBoot 示例,以及超 4W 行代码的电商微服务项目。

获取方式:点“在看”,关注公众号并回复 666 领取,更多内容陆续奉上。

文章有帮助的话,在看,转发吧。

谢谢支持哟 (*^__^*)

微信扫码关注该文公众号作者

戳这里提交新闻线索和高质量文章给我们。
相关阅读
人间再无刘三姐怎么开始学佛(十二)成佛就是成自己不指责不训练,一招帮助孩子解决上课注意力不集中问题不惑创投李祝捷:创业公司如何实现100倍增长?旅游市场象限重构,携程如何站稳“第一区”?链接与共生,城越UrbanLab如何实现建筑碳中和的价值创造敏感的罪魁祸首居然是ta?一招帮你杀菌、除霉、去尘螨、除臭去味!工作中如何时间管理?让《搞定》帮你搞定澳洲Kmart大批商品突然被贴上蓝色标签, 撕下后, 发现事情不简单…Spring在多线程环境下如何确保事务一致性银行倒闭只赔$25万?慌了的美国人发现事实不是这样,但这些账户不行!一篇带你搞定波士顿3W地区公立高中~降本增效成架构师必备技能:酷家乐如何实现全年数据库成本零增长【金融行业】明确分类标准,促进信托业务回归本源—简评《关于规范信托公司信托业务分类的通知》留学生到底怎么报税?超全加拿大报税指南,手把手教你搞定报税季!从简历、技巧到面试题精讲,带你搞定Java面试 | 极客时间银行倒闭只赔$25万?慌了的美国人发现事实不是这样,但这些账户不行通勤太久?半岛、南湾“交通房”地图帮你整理好了,一文统统搞定!麻了,代码改成多线程,竟有9大问题寓意不祥花,无辜任怨嗟离家出走追极光讲座预告 | 老学长支招! 教你搞定选课、科研、GPA!没有邓小平右派慢慢长夜无绝期口臭反反复复?这款源自北大专利研发的牙膏,帮你搞定端午特供|上班族宝妈如何轻松变大厨?一文帮你全搞定~Relate Anything来了!帮你搞定一切关系!用 Copliot 帮你搞定 Java 样板代码牛仔裤别乱买!4个选款教你搞定夏季穿搭,时髦又显瘦!支付系统中,提现流程如何设计纽约花粉过敏季提前来袭,这些小妙招帮你预防过敏冲! 加国全新‘生活补贴’最高可领$400+! 仍被网友吐槽! 没想到这招帮你省钱还健康!细数线程池的10个坑,面试线程不怕不怕啦JUC多线程:CountDownLatch、CyclicBarrier、Semaphore 同步器原理2023湾区最佳公立学区榜单公布:前十学区竟是它?一文统统搞定!不想黑头泛滥,油光满面,不到 40 元帮你搞定
logo
联系我们隐私协议©2024 redian.news
Redian新闻
Redian.news刊载任何文章,不代表同意其说法或描述,仅为提供更多信息,也不构成任何建议。文章信息的合法性及真实性由其作者负责,与Redian.news及其运营公司无关。欢迎投稿,如发现稿件侵权,或作者不愿在本网发表文章,请版权拥有者通知本网处理。