基于MQ的2PC分布式事务

作者: 套马杆的程序员 | 来源:发表于2021-01-22 17:06 被阅读0次
    [图:基于 MQ 实现 2PC 分布式事务的流程图(原图缺失)]

    上图阐释了如何基于 MQ 实现 2PC 分布式事务:

    • 一阶段为红线部分。
    • 二阶段为蓝线部分。

    图中展示了较为复杂的调用方式,S1调用S2、S3,S3又调用了S4。
    感谢seata开源社区大佬的帮助。虽然2pc本身存在很多问题,但是自己手动实现一遍还是学习到很多。
    本文仅做参考,不具备生产意义。
    seata社区陈建斌大佬指正的问题列表如下:
    问题
    第一:TM 需要有事务记录表来恢复事务。还要考虑这样一种情况:RM 没有任何异常,只是因为 TM 宕机,TM 的二阶段提交决议没有落库,结果 RM 本应提交的事务变成了回滚。
    第二:需要把 Connection 换为 XAConnection,使用 XA 协议来保证 RM 宕机后事务数据可恢复。
    第三:要保证消息队列中间件的高可用。
    第四:要防止资源悬挂问题。因为没有分支事务注册,很可能由于网络或其他因素出现“先发后至”,导致 TM 没有感知到这个 RM 的存在,这个 RM 就可能因为使用了 XA 协议(分支一直持有锁)而导致死锁。

    Show me the code

    根据上图,我们可以很好地实现如下代码(此处基于 RocketMQ 实现)。

    引入以下依赖:

    <!-- Spring AOP: GlobalTrxAspect and DataSourceAspect are AspectJ-annotation aspects. -->
    <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-starter-aop</artifactId>
            </dependency>
            <!-- Dubbo: used by DubboTrxPropagationFilter to pass trx_id to downstream services.
                 NOTE(review): "provided" presumably means the host application already ships Dubbo - confirm. -->
            <dependency>
                <groupId>org.apache.dubbo</groupId>
                <artifactId>dubbo</artifactId>
                <version>2.7.2</version>
                <scope>provided</scope>
            </dependency>
            <!-- RocketMQ Spring starter: RocketMQTemplate publishes the phase-two decision and
                 @RocketMQMessageListener consumes it. -->
            <dependency>
                <groupId>org.apache.rocketmq</groupId>
                <artifactId>rocketmq-spring-boot-starter</artifactId>
                <version>2.1.1</version>
            </dependency>
            <!-- NOTE(review): the code also imports com.alibaba.fastjson (TransactionMassageListener),
                 org.apache.commons.lang3 (DataSourceAspect) and org.apache.commons.collections
                 (TransactionMassageListener). None of them is declared here - confirm they arrive
                 transitively or add them explicitly. -->
    
    

    全局事务注解:此注解用于开启全局事务,真正的本地事务仍然交给 @Transactional 注解去执行。

    package com.xxx.mq.trx.config;
    
    import java.lang.annotation.ElementType;
    import java.lang.annotation.Inherited;
    import java.lang.annotation.Retention;
    import java.lang.annotation.RetentionPolicy;
    import java.lang.annotation.Target;
    
    /**
     * Marks a method as the entry point of a global (MQ-coordinated 2PC) transaction.
     *
     * <p>This annotation only opens the global transaction. The outermost annotated call
     * generates the trx id and, when it returns or throws, publishes the commit/rollback
     * decision. The actual local transaction is still run by {@code @Transactional}.
     *
     * <p>Deliberately not {@code @Inherited}: that meta-annotation only has an effect on
     * annotations placed on classes, and this one can only be placed on methods.
     *
     * @Author 姚仲杰
     * @Date 2021/1/2 21:36
     */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    public @interface GlobalTransaction {

    }
    
    

    全局事务切面

    package com.xxx.mq.trx.aspect;
    
    import com.xxx.mq.trx.config.TransactionConst;
    import com.xxx.mq.trx.core.TrxContextHolder;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.UUID;
    import org.apache.rocketmq.spring.core.RocketMQTemplate;
    import org.aspectj.lang.ProceedingJoinPoint;
    import org.aspectj.lang.annotation.Around;
    import org.aspectj.lang.annotation.Pointcut;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.messaging.Message;
    import org.springframework.messaging.support.MessageBuilder;
    import org.springframework.util.StringUtils;
    
    /**
     * @Description TODO
     * @Author 姚仲杰
     * @Date 2021/1/2 21:38
     */
    public class GlobalTrxAspect {
    
        @Autowired
        RocketMQTemplate rocketMQTemplate;
    
        @Pointcut("@annotation(com.xxx.mq.trx.config.GlobalTransaction)")
        public void pointcut(){}
    
        @Around("pointcut()")
        public void around(ProceedingJoinPoint joinPoint) throws Throwable {
            //方法执行前需生成trx_id
            //判断是否事务发起者,如果能从线程上下文取到事务id说明是参与者,如果取不到则是事务管理者。
            String trx_id = TrxContextHolder.getTrxId();
            boolean isManager = false;
            if (StringUtils.isEmpty(trx_id)) {
                UUID uuid = UUID.randomUUID();
                TrxContextHolder.setTrxId(uuid.toString());
                isManager=true;
            }
            Map map=new HashMap(2);
            map.put(TransactionConst.TRX_ID,trx_id);
            try {
                joinPoint.proceed();
                map.put(trx_id, TransactionConst.COMMIT);
            } catch (Throwable throwable) {
                map.put(trx_id, TransactionConst.ROLLBACK);
                throw throwable;
            }finally {
                //方法执行后需发送消息告知所有事务参与者是提交还是回滚
                if(isManager) {
                    Message msg = MessageBuilder.withPayload(map).build();
                    rocketMQTemplate.send(TransactionConst.TRX_TOPIC, msg);
                }
            }
        }
    }
    
    

    事务常量定义

    package com.xxx.mq.trx.config;
    
    /**
     * Shared constants for the MQ-based global transaction: message keys, decision values and
     * RocketMQ destinations.
     *
     * @Author 姚仲杰
     * @Date 2021/1/4 9:28
     */
    public interface TransactionConst {
        /** RocketMQ topic on which the transaction manager publishes the phase-two decision. */
        String TRX_TOPIC = "global_trx_topic";
        /** RocketMQ consumer group of the phase-two listener. */
        String TRX_GROUP = "global_trx_group";
        /** Message / RPC attachment key that carries the global trx id. */
        String TRX_ID = "trx_id";

        /** Decision value: commit every parked local transaction. */
        int COMMIT = 1;
        /** Decision value: roll back every parked local transaction. */
        int ROLLBACK = 0;
    }
    
    
    package com.xxx.mq.trx.aspect;
    
    import com.xxx.mq.trx.core.ConnectionProxy;
    import com.xxx.mq.trx.core.TrxContextHolder;
    import java.sql.Connection;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.locks.ReentrantLock;
    import org.apache.commons.lang3.StringUtils;
    import org.aspectj.lang.ProceedingJoinPoint;
    import org.aspectj.lang.annotation.Around;
    import org.aspectj.lang.annotation.Aspect;
    import org.springframework.beans.factory.ObjectFactory;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.stereotype.Component;
    
    /**
     * Intercepts {@code DataSource.getConnection(..)} so that connections opened inside a global
     * transaction are not finished locally.
     *
     * <p>When a trx id is bound to the current thread, the real connection is wrapped in a
     * {@link ConnectionProxy} (whose commit/rollback/close do nothing). The proxy is registered
     * under that trx id in {@link TrxContextHolder}, where the MQ listener later commits or rolls
     * it back. Outside a global transaction the real connection is returned unchanged.
     *
     * <p>NOTE(review): a connection registered after the listener has already processed and
     * removed this trx id would never be finished. Confirm that cannot happen in practice.
     *
     * @Author 姚仲杰
     * @Date 2021/01/04 11:46
     */
    @Aspect
    @Component
    public class DataSourceAspect {
        /** Supplies a fresh prototype-scoped {@link ConnectionProxy} for every intercepted call. */
        @Autowired
        ObjectFactory<ConnectionProxy> bean;

        /** Makes the get-or-create-then-add sequence on the per-trx connection list atomic. */
        ReentrantLock lock = new ReentrantLock();

        /**
         * Wraps and registers the connection when a global transaction is active.
         *
         * @param joinPoint the intercepted {@code getConnection} call
         * @return the {@link ConnectionProxy} inside a global transaction, else the real connection
         * @throws Throwable whatever {@code getConnection} threw
         */
        @Around("execution(* javax.sql.DataSource.getConnection(..))")
        public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
            Connection conn = (Connection) joinPoint.proceed();
            String trxId = TrxContextHolder.getTrxId();
            if (StringUtils.isBlank(trxId)) {
                return conn;
            }
            ConnectionProxy connectionProxy = bean.getObject();
            connectionProxy.setConnection(conn);
            lock.lock();
            try {
                List<ConnectionProxy> list = TrxContextHolder.getConnections(trxId);
                if (list == null) {
                    // The MQ listener may iterate this list while request threads are still
                    // appending to it. A plain ArrayList could throw
                    // ConcurrentModificationException there; a copy-on-write list cannot.
                    list = new java.util.concurrent.CopyOnWriteArrayList<>();
                }
                list.add(connectionProxy);
                TrxContextHolder.setConnections(trxId, list);
            } finally {
                lock.unlock();
            }
            return connectionProxy;
        }

    }
    
    

    连接代理:让 @Transactional 注解触发的提交、回滚和关闭都变成空操作,真正的提交或回滚转交给我们自己的 MQ 通知来执行。

    package com.xxx.mq.trx.core;
    
    import java.sql.Array;
    import java.sql.Blob;
    import java.sql.CallableStatement;
    import java.sql.Clob;
    import java.sql.Connection;
    import java.sql.DatabaseMetaData;
    import java.sql.NClob;
    import java.sql.PreparedStatement;
    import java.sql.SQLClientInfoException;
    import java.sql.SQLException;
    import java.sql.SQLWarning;
    import java.sql.SQLXML;
    import java.sql.Savepoint;
    import java.sql.Statement;
    import java.sql.Struct;
    import java.util.Map;
    import java.util.Properties;
    import java.util.concurrent.Executor;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.factory.config.ConfigurableBeanFactory;
    import org.springframework.context.annotation.Scope;
    import org.springframework.stereotype.Component;
    
    /**
     * {@link Connection} wrapper that postpones the end of a local transaction until the global
     * decision arrives over MQ.
     *
     * <p>{@link #commit()}, {@link #rollback()} and {@link #close()} are deliberately no-ops, so
     * the {@code @Transactional} machinery cannot finish the local transaction on its own.
     * {@link #setAutoCommit(boolean)} always forces auto-commit off. The real commit/rollback and
     * close happen in {@link #notify(int)}. Every other method delegates to the wrapped connection.
     *
     * <p>Prototype-scoped: a new instance is created for each wrapped connection.
     *
     * @Author 姚仲杰
     * @Date 2021/1/4 10:48
     */
    @Component
    @Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE)
    public class ConnectionProxy implements Connection {
        private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionProxy.class);

        /** Decision value meaning "commit" (same value as {@code TransactionConst.COMMIT}). */
        private static final int COMMIT_STATE = 1;

        private Connection connection;

        /**
         * Applies the global decision to the wrapped connection and then releases it.
         * Invoked by the MQ listener once the decision for this connection's trx id arrives.
         *
         * @param state {@code 1} to commit; any other value rolls back
         */
        public void notify(int state) {
            try {
                if (state == COMMIT_STATE) {
                    connection.commit();
                } else {
                    connection.rollback();
                }
            } catch (Exception e) {
                LOGGER.error(e.getLocalizedMessage(), e);
            } finally {
                // Always release the physical connection, even when commit/rollback failed.
                // Nothing else ever closes it, because this proxy's close() is a no-op.
                try {
                    connection.close();
                } catch (Exception e) {
                    LOGGER.error(e.getLocalizedMessage(), e);
                }
            }
        }

        /** Ignores {@code autoCommit} and always disables auto-commit on the wrapped connection. */
        @Override
        public void setAutoCommit(boolean autoCommit) throws SQLException {
            connection.setAutoCommit(false);
        }

        /** No-op: the real commit happens in {@link #notify(int)}. */
        @Override
        public void commit() throws SQLException {
        }

        /** No-op: the real rollback happens in {@link #notify(int)}. */
        @Override
        public void rollback() throws SQLException {
        }

        /** No-op: the wrapped connection is closed in {@link #notify(int)}. */
        @Override
        public void close() throws SQLException {
        }

        @Override
        public boolean getAutoCommit() throws SQLException {
            return connection.getAutoCommit();
        }

        @Override
        public Statement createStatement() throws SQLException {
            return connection.createStatement();
        }

        @Override
        public PreparedStatement prepareStatement(String sql) throws SQLException {
            return connection.prepareStatement(sql);
        }

        @Override
        public CallableStatement prepareCall(String sql) throws SQLException {
            return connection.prepareCall(sql);
        }

        @Override
        public String nativeSQL(String sql) throws SQLException {
            return connection.nativeSQL(sql);
        }

        /** Reports the wrapped connection's state; stays {@code false} after a no-op {@link #close()}. */
        @Override
        public boolean isClosed() throws SQLException {
            return connection.isClosed();
        }

        @Override
        public DatabaseMetaData getMetaData() throws SQLException {
            return connection.getMetaData();
        }

        @Override
        public void setReadOnly(boolean readOnly) throws SQLException {
            connection.setReadOnly(readOnly);
        }

        @Override
        public boolean isReadOnly() throws SQLException {
            return connection.isReadOnly();
        }

        @Override
        public void setCatalog(String catalog) throws SQLException {
            connection.setCatalog(catalog);
        }

        @Override
        public String getCatalog() throws SQLException {
            return connection.getCatalog();
        }

        @Override
        public void setTransactionIsolation(int level) throws SQLException {
            connection.setTransactionIsolation(level);
        }

        @Override
        public int getTransactionIsolation() throws SQLException {
            return connection.getTransactionIsolation();
        }

        @Override
        public SQLWarning getWarnings() throws SQLException {
            return connection.getWarnings();
        }

        @Override
        public void clearWarnings() throws SQLException {
            connection.clearWarnings();
        }

        @Override
        public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
            return connection.createStatement(resultSetType, resultSetConcurrency);
        }

        @Override
        public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency)
            throws SQLException {
            return connection.prepareStatement(sql, resultSetType, resultSetConcurrency);
        }

        @Override
        public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
            return connection.prepareCall(sql, resultSetType, resultSetConcurrency);
        }

        @Override
        public Map<String, Class<?>> getTypeMap() throws SQLException {
            return connection.getTypeMap();
        }

        @Override
        public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
            connection.setTypeMap(map);
        }

        @Override
        public void setHoldability(int holdability) throws SQLException {
            connection.setHoldability(holdability);
        }

        @Override
        public int getHoldability() throws SQLException {
            return connection.getHoldability();
        }

        @Override
        public Savepoint setSavepoint() throws SQLException {
            return connection.setSavepoint();
        }

        @Override
        public Savepoint setSavepoint(String name) throws SQLException {
            return connection.setSavepoint(name);
        }

        /** Savepoint rollbacks are delegated: they stay inside the still-open local transaction. */
        @Override
        public void rollback(Savepoint savepoint) throws SQLException {
            connection.rollback(savepoint);
        }

        @Override
        public void releaseSavepoint(Savepoint savepoint) throws SQLException {
            connection.releaseSavepoint(savepoint);
        }

        @Override
        public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability)
            throws SQLException {
            return connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
        }

        @Override
        public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency,
            int resultSetHoldability) throws SQLException {
            return connection.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
        }

        @Override
        public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency,
            int resultSetHoldability) throws SQLException {
            return connection.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
        }

        @Override
        public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
            return connection.prepareStatement(sql, autoGeneratedKeys);
        }

        @Override
        public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
            return connection.prepareStatement(sql, columnIndexes);
        }

        @Override
        public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
            return connection.prepareStatement(sql, columnNames);
        }

        @Override
        public Clob createClob() throws SQLException {
            return connection.createClob();
        }

        @Override
        public Blob createBlob() throws SQLException {
            return connection.createBlob();
        }

        @Override
        public NClob createNClob() throws SQLException {
            return connection.createNClob();
        }

        @Override
        public SQLXML createSQLXML() throws SQLException {
            return connection.createSQLXML();
        }

        @Override
        public boolean isValid(int timeout) throws SQLException {
            return connection.isValid(timeout);
        }

        @Override
        public void setClientInfo(String name, String value) throws SQLClientInfoException {
            connection.setClientInfo(name, value);
        }

        @Override
        public void setClientInfo(Properties properties) throws SQLClientInfoException {
            connection.setClientInfo(properties);
        }

        @Override
        public String getClientInfo(String name) throws SQLException {
            return connection.getClientInfo(name);
        }

        @Override
        public Properties getClientInfo() throws SQLException {
            return connection.getClientInfo();
        }

        @Override
        public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
            return connection.createArrayOf(typeName, elements);
        }

        @Override
        public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
            return connection.createStruct(typeName, attributes);
        }

        @Override
        public void setSchema(String schema) throws SQLException {
            connection.setSchema(schema);
        }

        @Override
        public String getSchema() throws SQLException {
            return connection.getSchema();
        }

        @Override
        public void abort(Executor executor) throws SQLException {
            connection.abort(executor);
        }

        @Override
        public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
            connection.setNetworkTimeout(executor, milliseconds);
        }

        @Override
        public int getNetworkTimeout() throws SQLException {
            return connection.getNetworkTimeout();
        }

        @Override
        public <T> T unwrap(Class<T> iface) throws SQLException {
            return connection.unwrap(iface);
        }

        @Override
        public boolean isWrapperFor(Class<?> iface) throws SQLException {
            return connection.isWrapperFor(iface);
        }

        /** @return the wrapped physical connection */
        public Connection getConnection() {
            return connection;
        }

        /** @param connection the physical connection to wrap; must be set before any other call */
        public void setConnection(Connection connection) {
            this.connection = connection;
        }
    }
    
    

    事务上下文

    package com.xxx.mq.trx.core;
    
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * Per-thread string key/value store (used by {@code TrxContextHolder} to hold the trx id).
     *
     * <p>Each thread lazily gets its own {@link HashMap}. Values written on one thread are not
     * visible from any other thread.
     *
     * @Author 姚仲杰
     * @Date 2020/12/28 11:42
     */
    public class TrxContext {

        /** One map per thread, created on first access. */
        private final ThreadLocal<Map<String, String>> store = ThreadLocal.withInitial(HashMap::new);

        /** @return the current thread's map, never {@code null} */
        private Map<String, String> current() {
            return store.get();
        }

        /**
         * Stores {@code value} under {@code key} for the current thread.
         *
         * @return the previous value for {@code key}, or {@code null} if there was none
         */
        public String put(String key, String value) {
            Map<String, String> values = current();
            return values.put(key, value);
        }

        /** @return the current thread's value for {@code key}, or {@code null} */
        public String get(String key) {
            Map<String, String> values = current();
            return values.get(key);
        }

        /** @return the removed value for {@code key}, or {@code null} if there was none */
        public String remove(String key) {
            Map<String, String> values = current();
            return values.remove(key);
        }

        /** @return the current thread's live backing map (not a copy) */
        public Map<String, String> entries() {
            return current();
        }
    }
    
    

    事务上下文持有者缓存了trxId,以及全局事务连接列表等属性。

    package com.xxx.mq.trx.core;
    
    import java.util.List;
    import java.util.concurrent.ConcurrentHashMap;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.util.CollectionUtils;
    import org.springframework.util.StringUtils;
    
    /**
     * Static holder for global-transaction state.
     *
     * <p>The current trx id is kept per thread in a {@link TrxContext}. The connections registered
     * for each trx id live in a process-wide map, so code running on another thread (the MQ
     * listener) can find and finish them.
     *
     * @Author 姚仲杰
     * @Date 2020/12/28 11:46
     */
    public class TrxContextHolder {
        private static final Logger LOGGER = LoggerFactory.getLogger(TrxContextHolder.class);

        public static final TrxContext TRX_CONTEXT_HOLDER=new TrxContext();

        /**
         * trx id -> connections registered under it. The reference is never reassigned, so it is
         * {@code final}. The former {@code volatile} added nothing, since ConcurrentHashMap
         * already publishes its entries safely.
         */
        private static final ConcurrentHashMap<String, List<ConnectionProxy>> CONNECTIONS =
            new ConcurrentHashMap<>();

        /** Key under which the current trx id is stored in {@link #TRX_CONTEXT_HOLDER}. */
        public static final String TRX_ID="TRX_ID";

        /** @return the trx id bound to the current thread, or {@code null} if none is bound */
        public static String getTrxId(){
            return TRX_CONTEXT_HOLDER.get(TRX_ID);
        }

        /** Binds {@code trxId} to the current thread, replacing any previous value. */
        public static void setTrxId(String trxId){
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("set trx_id:[{}]", trxId);
            }
            TRX_CONTEXT_HOLDER.put(TRX_ID, trxId);
        }

        /** Unbinds the current thread's trx id. @return the removed id, or {@code null} */
        public static String removeTrxId() {
            String trxId = TRX_CONTEXT_HOLDER.remove(TRX_ID);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("remove trx_id:[{}] ", trxId);
            }
            return trxId;
        }

        /**
         * @return the connections registered under {@code trxId}, or {@code null} if none
         * @throws IllegalArgumentException if {@code trxId} is null or empty
         */
        public static List<ConnectionProxy> getConnections(String trxId){
            requireTrxId(trxId);
            return CONNECTIONS.get(trxId);
        }

        /**
         * Registers (or replaces) the connection list for {@code trxId}.
         *
         * @throws IllegalArgumentException if {@code trxId} is null/empty or {@code connections}
         *     is null/empty
         */
        public static void setConnections(String trxId,List<ConnectionProxy> connections){
            requireTrxId(trxId);
            if (CollectionUtils.isEmpty(connections)){
                LOGGER.error("connections can not be empty,require at least one connection");
                throw new IllegalArgumentException(
                    "connections can not be empty, require at least one connection (trx_id=" + trxId + ")");
            }
            CONNECTIONS.put(trxId,connections);
        }

        /**
         * Forgets every connection registered under {@code trxId}.
         *
         * @throws IllegalArgumentException if {@code trxId} is null or empty
         */
        public static void removeConnections(String trxId){
            requireTrxId(trxId);
            CONNECTIONS.remove(trxId);
        }

        /** Rejects a null or empty trx id with a logged, descriptive exception. */
        private static void requireTrxId(String trxId) {
            if (StringUtils.isEmpty(trxId)){
                LOGGER.error("trx_id can not be empty");
                throw new IllegalArgumentException("trx_id can not be empty");
            }
        }
    }
    
    

    二阶段提交mq监听器

    package com.xxx.mq.trx.core;
    
    import com.alibaba.fastjson.JSON;
    import com.xxx.mq.trx.config.TransactionConst;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.commons.collections.CollectionUtils;
    import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
    import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
    import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
    import org.apache.rocketmq.common.message.MessageExt;
    import org.apache.rocketmq.spring.annotation.ConsumeMode;
    import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    /**
     * @Description TODO
     * @Author 姚仲杰
     * @Date 2021/1/4 11:18
     */
    @RocketMQMessageListener(consumeMode = ConsumeMode.CONCURRENTLY,topic = TransactionConst.TRX_TOPIC,consumerGroup = TransactionConst.TRX_GROUP)
    public class TransactionMassageListener implements MessageListenerConcurrently {
        public static final Logger LOGGER= LoggerFactory.getLogger(TransactionMassageListener.class);
        @Override
        public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
            ConsumeConcurrentlyContext context) {
            LOGGER.info("receive global transaction message: {}",msgs);
            MessageExt messageExt = msgs.get(0);
            //如果本地获取不到事务等待连接直接返回消费成功,因为这是广播模式。
            try {
                String s = new String(messageExt.getBody(), "utf-8");
                Map map = JSON.parseObject(s, HashMap.class);
                String trxId= (String) map.get(TransactionConst.TRX_ID);
                int state= (int) map.get(trxId);
                List<ConnectionProxy> connections = TrxContextHolder.getConnections(trxId);
                if (!CollectionUtils.isEmpty(connections)){
                    try {
                        connections.forEach(cp -> cp.notify(state));
                    }finally {
                        TrxContextHolder.removeConnections(trxId);
                    }
                }
            }catch (Throwable e){
               return  ConsumeConcurrentlyStatus.RECONSUME_LATER;
            }
            return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
        }
    }
    
    

    dubbo事务id传播过滤器

    package com.xxx.mq.trx.integration.dubbo;
    
    import com.xxx.mq.trx.config.TransactionConst;
    import com.xxx.mq.trx.core.TrxContextHolder;
    import org.apache.dubbo.common.extension.Activate;
    import org.apache.dubbo.rpc.Filter;
    import org.apache.dubbo.rpc.Invocation;
    import org.apache.dubbo.rpc.Invoker;
    import org.apache.dubbo.rpc.Result;
    import org.apache.dubbo.rpc.RpcContext;
    import org.apache.dubbo.rpc.RpcException;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    /**
     * Dubbo filter that carries the global trx id across RPC boundaries.
     *
     * <p>Consumer side: if the calling thread already has a trx id bound, it is copied into the
     * RPC attachment under {@code TransactionConst.TRX_ID}. Provider side: if the thread has no
     * trx id but the incoming attachment carries one, that id is bound to the thread for the
     * duration of the invocation and unbound afterwards.
     *
     * @Author 姚仲杰
     * @Date 2021/01/04 11:46
     */
    @Activate(group = {"provider", "consumer"}, order = 100)
    public class DubboTrxPropagationFilter implements Filter {

        private static final Logger LOGGER = LoggerFactory.getLogger(DubboTrxPropagationFilter.class);

        @Override
        public Result invoke(Invoker<?> invoker, Invocation invocation) throws RpcException {
            RpcContext rpcContext = RpcContext.getContext();
            String localTrxId = TrxContextHolder.getTrxId();
            String remoteTrxId = rpcContext.getAttachment(TransactionConst.TRX_ID);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("trxId in TrxContext[{}] trxId in RpcContext[{}]", localTrxId, remoteTrxId);
            }

            // Outbound call inside a global transaction: forward our id downstream.
            if (localTrxId != null) {
                rpcContext.setAttachment(TransactionConst.TRX_ID, localTrxId);
                return invoker.invoke(invocation);
            }

            // No global transaction anywhere: plain pass-through.
            if (remoteTrxId == null) {
                return invoker.invoke(invocation);
            }

            // Inbound call that joined a global transaction: bind for the duration of the call only.
            TrxContextHolder.setTrxId(remoteTrxId);
            try {
                return invoker.invoke(invocation);
            } finally {
                TrxContextHolder.removeTrxId();
            }
        }

    }
    
    

    相关文章

      网友评论

        本文标题:基于MQ的2PC分布式事务

        本文链接:https://www.haomeiwen.com/subject/zhcezktx.html