美文网首页springboot技术
springboot配置多数据源后mybatis拦截器失效

springboot配置多数据源后mybatis拦截器失效

作者: 那些年搬过的砖 | 来源:发表于2018-01-18 23:22 被阅读0次
    关键字:springcloud、mybatis、多数据源负载均衡、拦截器动态分页
    
    
    配置文件通过Spring Cloud Config进行远程分布式配置,采用阿里Druid数据源,并支持一主多从的读写分离。分页组件通过拦截器拦截以Page结尾的方法名,动态地设置total总数。
    1. 解析配置文件初始化数据源
    /**
     * Creates the master (write) data source and the slave (read) data sources from configuration.
     */
    @Configuration
    public class DataSourceConfiguration {
        /**
         * Pool implementation used for every data source built here, read from
         * {@code spring.datasource.type} (e.g. DruidDataSource).
         */
        @Value("${spring.datasource.type}")
        private Class<? extends DataSource> dataSourceType;

        /**
         * Master (write) data source; pool settings are bound from {@code spring.datasource.*}.
         *
         * @return the master data source, marked {@code @Primary}
         */
        @Bean(name = "masterDataSource", destroyMethod = "close")
        @Primary
        @ConfigurationProperties(prefix = "spring.datasource")
        public DataSource masterDataSource() {
            return DataSourceBuilder.create().type(dataSourceType).build();
        }

        /**
         * First slave (read) data source; pool settings are bound from {@code spring.slave0.*}.
         * The pool class still comes from {@code spring.datasource.type}.
         *
         * @return the slave data source
         */
        @Bean(name = "slaveDataSource0")
        @ConfigurationProperties(prefix = "spring.slave0")
        public DataSource slaveDataSource0() {
            return DataSourceBuilder.create().type(dataSourceType).build();
        }

        /**
         * All slave data sources, in a stable order (callers may use the list index as a key).
         *
         * @return a mutable list containing every slave data source
         */
        @Bean(name = "slaveDataSources")
        public List<DataSource> slaveDataSources() {
            List<DataSource> slaveDataSources = new ArrayList<>();
            // In a @Configuration class Spring proxies @Bean methods, so this returns the
            // slaveDataSource0 singleton rather than building a second pool.
            slaveDataSources.add(slaveDataSource0());
            return slaveDataSources;
        }
    }
    
    2. 定义数据源枚举类型
    /**
     * Data-source roles; {@link #getType()} is the key stored per thread and used for routing.
     *
     * NOTE(review): the setters mutate shared enum constants for the whole JVM — avoid calling them.
     */
    public enum DataSourceType {
        /** Write (master) database. */
        master("master", "master"),
        /** Read (slave) database(s). */
        slave("slave", "slave");

        private String type;
        private String name;

        DataSourceType(String type, String name) {
            this.name = name;
            this.type = type;
        }

        /** @return the routing key of this role */
        public String getType() {
            return this.type;
        }

        /** @param type new routing key for this role */
        public void setType(String type) {
            this.type = type;
        }

        /** @return the display name of this role */
        public String getName() {
            return this.name;
        }

        /** @param name new display name for this role */
        public void setName(String name) {
            this.name = name;
        }
    }
    
    3. ThreadLocal保存数据源类型
    /**
     * Holds the data-source routing key ("master" or "slave") for the current thread.
     */
    public class DataSourceContextHolder {
        /** Per-thread routing key; {@code null} until one of the setters runs on the thread. */
        private static final ThreadLocal<String> CURRENT_TYPE = new ThreadLocal<String>();

        /** @return the underlying ThreadLocal holding the routing key */
        public static ThreadLocal<String> getLocal() {
            return CURRENT_TYPE;
        }

        /** Tags the current thread with the slave (read) key. */
        public static void slave() {
            String slaveKey = DataSourceType.slave.getType();
            CURRENT_TYPE.set(slaveKey);
        }

        /** Tags the current thread with the master (write) key. */
        public static void master() {
            String masterKey = DataSourceType.master.getType();
            CURRENT_TYPE.set(masterKey);
        }

        /** @return the current thread's routing key, or {@code null} if none has been set */
        public static String getJdbcType() {
            String current = CURRENT_TYPE.get();
            return current;
        }

        /** Removes the routing key from the current thread. */
        public static void clearDataSource() {
            CURRENT_TYPE.remove();
        }
    }
    
    4. 自定义sqlSessionProxy,并将数据源填充到DataSourceRoute
    /**
     * Builds the MyBatis SqlSessionFactory on top of a routing data source that combines the
     * master and all slave data sources.
     */
    @Configuration
    @ConditionalOnClass({EnableTransactionManagement.class})
    @Import({DataSourceConfiguration.class})
    public class DataSourceSqlSessionFactory {
        private Logger logger = Logger.getLogger(DataSourceSqlSessionFactory.class);

        @Value("${spring.datasource.type}")
        private Class<? extends DataSource> dataSourceType;

        @Value("${mybatis.mapper-locations}")
        private String mapperLocations;

        @Value("${mybatis.type-aliases-package}")
        private String aliasesPackage;

        /** Configured slave count; only cross-checked against the slave beans actually present. */
        @Value("${slave.datasource.number}")
        private int dataSourceNumber;

        @Resource(name = "masterDataSource")
        private DataSource masterDataSource;

        @Resource(name = "slaveDataSources")
        private List<DataSource> slaveDataSources;

        /**
         * Creates the SqlSessionFactory over {@link #roundRobinDataSourceProxy()}.
         *
         * @return the built factory, with underscore-to-camel-case mapping enabled
         * @throws Exception if mapper resources cannot be resolved or the factory fails to build
         */
        @Bean
        @ConditionalOnMissingBean
        public SqlSessionFactory sqlSessionFactory() throws Exception {
            logger.info("======================= init sqlSessionFactory");
            SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean();
            sqlSessionFactoryBean.setDataSource(roundRobinDataSourceProxy());
            PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
            sqlSessionFactoryBean.setMapperLocations(resolver.getResources(mapperLocations));
            sqlSessionFactoryBean.setTypeAliasesPackage(aliasesPackage);
            // getObject() builds the factory on its first call and caches it; fetch it once.
            SqlSessionFactory sqlSessionFactory = sqlSessionFactoryBean.getObject();
            sqlSessionFactory.getConfiguration().setMapUnderscoreToCamelCase(true);
            return sqlSessionFactory;
        }

        /**
         * Routing data source: the master is registered under the master type key, and slave i
         * under the Integer key i.
         *
         * <p>The router is sized from the slave list itself, so it can only pick keys that exist.
         * If {@code slave.datasource.number} disagrees, a warning is logged.
         *
         * @return the routing data source, defaulting to the master
         */
        @Bean(name = "roundRobinDataSourceProxy")
        public AbstractRoutingDataSource roundRobinDataSourceProxy() {
            logger.info("======================= init robinDataSourceProxy");
            Map<Object, Object> targetDataSources = new HashMap<>();
            targetDataSources.put(DataSourceType.master.getType(), masterDataSource);
            int slaveCount = 0;
            if (null != slaveDataSources) {
                slaveCount = slaveDataSources.size();
                for (int i = 0; i < slaveCount; i++) {
                    targetDataSources.put(i, slaveDataSources.get(i));
                }
            }
            if (dataSourceNumber != slaveCount) {
                // A router sized larger than the registered slaves would pick missing lookup keys.
                logger.warn("slave.datasource.number=" + dataSourceNumber + " does not match the "
                        + slaveCount + " registered slave data sources; using " + slaveCount);
            }
            DataSourceRoute proxy = new DataSourceRoute(slaveCount);
            proxy.setDefaultTargetDataSource(masterDataSource);
            proxy.setTargetDataSources(targetDataSources);
            return proxy;
        }
    }
    
    5. 自定义路由
    /**
     * Routing data source. Statements tagged "master" go to the master. Other tagged statements
     * are spread randomly over the slaves, which are registered under the Integer lookup keys
     * {@code 0 .. dataSourceNumber-1}.
     */
    public class DataSourceRoute extends AbstractRoutingDataSource {

        private Logger logger = Logger.getLogger(DataSourceRoute.class);

        /** Number of slave data sources; {@code <= 0} means every lookup resolves to the master. */
        private final int dataSourceNumber;

        /**
         * @param dataSourceNumber number of slave data sources registered under keys 0..n-1
         */
        public DataSourceRoute(int dataSourceNumber) {
            this.dataSourceNumber = dataSourceNumber;
        }

        /**
         * Picks the lookup key for the connection being requested.
         *
         * @return the master key when the thread is tagged master, carries no tag at all (no
         *         aspect ran on this thread), or no slaves exist; otherwise a random slave index
         */
        @Override
        protected Object determineCurrentLookupKey() {
            String typeKey = DataSourceContextHolder.getJdbcType();
            logger.info("==================== swtich dataSource:" + typeKey);
            String masterKey = DataSourceType.master.getType();
            // An untagged thread falls back to the master instead of throwing a NullPointerException.
            // With no slaves there is no valid random index, so reads use the master too.
            if (typeKey == null || masterKey.equals(typeKey) || dataSourceNumber <= 0) {
                return masterKey;
            }
            // Random slave; ThreadLocalRandom avoids allocating a new Random per lookup.
            return java.util.concurrent.ThreadLocalRandom.current().nextInt(dataSourceNumber);
        }
    }
    
    6. 定义切面,在dao层(mapper)方法执行前切换数据源
    /**
     * Aspect that tags the current thread with the read or write data source before each mapper
     * call, based on the mapper method's name prefix.
     *
     * NOTE(review): in the code shown the thread's key is never cleared. A mapper method that
     * matches neither pointcut therefore keeps whatever key the thread last had — confirm this
     * is acceptable.
     */
    @Aspect
    @Component
    public class DataSourceAop {

        private Logger log = Logger.getLogger(DataSourceAop.class);

        /** Read-style mapper methods (get/isExist/select/count/list/query/find/search) use a slave. */
        @Before("execution(* com.dbq.iot.mapper..*.get*(..)) || execution(* com.dbq.iot.mapper..*.isExist*(..)) " +
                "|| execution(* com.dbq.iot.mapper..*.select*(..)) || execution(* com.dbq.iot.mapper..*.count*(..)) " +
                "|| execution(* com.dbq.iot.mapper..*.list*(..)) || execution(* com.dbq.iot.mapper..*.query*(..))" +
                "|| execution(* com.dbq.iot.mapper..*.find*(..))|| execution(* com.dbq.iot.mapper..*.search*(..))")
        public void setSlaveDataSourceType(JoinPoint joinPoint) {
            DataSourceContextHolder.slave();
            String methodName = joinPoint.getSignature().getName();
            log.info("=========slave, method:" + methodName);
        }

        /** Write-style mapper methods (add/del/update/insert/create/delete/remove/save/relieve/edit) use the master. */
        @Before("execution(* com.dbq.iot.mapper..*.add*(..)) || execution(* com.dbq.iot.mapper..*.del*(..))" +
                "||execution(* com.dbq.iot.mapper..*.upDate*(..)) || execution(* com.dbq.iot.mapper..*.insert*(..))" +
                "||execution(* com.dbq.iot.mapper..*.create*(..)) || execution(* com.dbq.iot.mapper..*.update*(..))" +
                "||execution(* com.dbq.iot.mapper..*.delete*(..)) || execution(* com.dbq.iot.mapper..*.remove*(..))" +
                "||execution(* com.dbq.iot.mapper..*.save*(..)) || execution(* com.dbq.iot.mapper..*.relieve*(..))" +
                "|| execution(* com.dbq.iot.mapper..*.edit*(..))")
        public void setMasterDataSourceType(JoinPoint joinPoint) {
            DataSourceContextHolder.master();
            String methodName = joinPoint.getSignature().getName();
            log.info("=========master, method:" + methodName);
        }
    }
    
    7. 最后在写库增加事务管理
    /**
     * Registers the Spring transaction manager for the master (write) database.
     */
    @Configuration
    @Import({DataSourceConfiguration.class})
    public class DataSouceTranscation extends DataSourceTransactionManagerAutoConfiguration {

        private Logger log = Logger.getLogger(DataSouceTranscation.class);

        @Resource(name = "masterDataSource")
        private DataSource masterDataSource;

        /**
         * Creates the "transactionManager" bean bound to the master data source.
         *
         * NOTE(review): the SqlSessionFactory in this setup is built on the routing
         * "roundRobinDataSourceProxy", not on masterDataSource. Spring binds transactional
         * connections per DataSource instance, so confirm that @Transactional really covers
         * the MyBatis calls.
         *
         * @return the transaction manager
         */
        @Bean(name = "transactionManager")
        public DataSourceTransactionManager transactionManagers() {
            log.info("===================== init transactionManager");
            DataSourceTransactionManager manager = new DataSourceTransactionManager(masterDataSource);
            return manager;
        }
    }
    
    8. 在配置文件中增加数据源配置
    # Master (write) database.
    # JDBC URL query parameters are separated by a plain '&'; '&amp;' is only needed inside XML.
    spring.datasource.name=writedb
    spring.datasource.url=jdbc:mysql://192.168.0.1/master?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
    spring.datasource.username=root
    spring.datasource.password=1234
    spring.datasource.type=com.alibaba.druid.pool.DruidDataSource
    spring.datasource.driver-class-name=com.mysql.jdbc.Driver
    spring.datasource.filters=stat
    spring.datasource.initialSize=20
    spring.datasource.minIdle=20
    spring.datasource.maxActive=200
    spring.datasource.maxWait=60000

    # Number of slave databases (should match the slave data sources actually defined below)
    slave.datasource.number=1

    # Slave (read) database 0.
    # NOTE(review): DataSourceConfiguration builds the slave pool with spring.datasource.type,
    # so spring.slave0.type is probably not used — confirm.
    spring.slave0.name=readdb
    spring.slave0.url=jdbc:mysql://192.168.0.2/slave?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
    spring.slave0.username=root
    spring.slave0.password=1234
    spring.slave0.type=com.alibaba.druid.pool.DruidDataSource
    spring.slave0.driver-class-name=com.mysql.jdbc.Driver
    spring.slave0.filters=stat
    spring.slave0.initialSize=20
    spring.slave0.minIdle=20
    spring.slave0.maxActive=200
    spring.slave0.maxWait=60000
    
    这样就实现了在Spring Cloud框架下的读写分离,并且支持多个从库的负载均衡(这里简单地通过随机分配实现;也有网友通过算法实现平均分配,具体做法是用一个线程安全的自增计数器(例如AtomicInteger)对从库数量取余。个人觉得没太大必要。如果有大神有更好的方法,欢迎一起探讨。)
    MyBatis分页可通过dao层的拦截器对特定方法进行拦截,拦截后添加自己的逻辑代码,比如计算total等。具体代码如下(参考了网友的代码,主要是通过@Intercepts注解实现):
    /**
     * MyBatis plugin that pages statements whose MappedStatement id matches {@code pageSqlId}
     * (default {@code .*Page$}).
     *
     * <p>For each matching statement it:
     * <ul>
     *   <li>appends a MySQL {@code limit} clause to the SQL;</li>
     *   <li>runs a {@code count(0)} query on the same connection;</li>
     *   <li>writes the total count and page count into the {@code page} property of the
     *       statement's parameter object.</li>
     * </ul>
     *
     * <p>Only MySQL syntax is generated, whatever {@code dialect} is set to.
     */
    @Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
    public class PageInterceptor implements Interceptor {
        private static final Log logger = LogFactory.getLog(PageInterceptor.class);
        private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
        private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
        private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();
        /** Database type used when none is configured. */
        private static final String defaultDialect = "mysql";
        /** Regex matched against MappedStatement ids when none is configured. */
        private static final String defaultPageSqlId = ".*Page$";
        /** Configured database type; resolved to the default on first use when blank. */
        private String dialect = "";
        /** Configured MappedStatement-id regex; resolved to the default on first use when blank. */
        private String pageSqlId = "";

        /**
         * Rewrites matching statements into paged SQL and fills in the page totals, then lets
         * {@code prepare} proceed.
         *
         * <p>Statements that do not match are passed through unchanged. So are matching
         * statements whose parameter object has a {@code null} page.
         *
         * @throws NullPointerException if a matching statement has no parameter object
         */
        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
            MetaObject metaStatementHandler = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,
                    DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
            // The handler may be wrapped by several plugin proxies (see Plugin.wrap). Unwrap them:
            // first follow the proxy chain via "h", then the last proxy's "target".
            while (metaStatementHandler.hasGetter("h")) {
                Object object = metaStatementHandler.getValue("h");
                metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY,
                        DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
            }
            while (metaStatementHandler.hasGetter("target")) {
                Object object = metaStatementHandler.getValue("target");
                metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY,
                        DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
            }
            if (null == dialect || "".equals(dialect)) {
                logger.warn("Property dialect is not setted,use default 'mysql' ");
                dialect = defaultDialect;
            }
            if (null == pageSqlId || "".equals(pageSqlId)) {
                logger.warn("Property pageSqlId is not setted,use default '.*Page$' ");
                pageSqlId = defaultPageSqlId;
            }
            MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
            // Only statements whose id matches pageSqlId are rewritten.
            if (mappedStatement.getId().matches(pageSqlId)) {
                BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
                if (boundSql.getParameterObject() == null) {
                    throw new NullPointerException("parameterObject is null!");
                }
                PageParameter page = (PageParameter) metaStatementHandler
                        .getValue("delegate.boundSql.parameterObject.page");
                // Without a page there is nothing to limit or count. Leave the statement untouched
                // rather than dereferencing a null page in setPageParameter.
                if (page != null) {
                    String sql = boundSql.getSql();
                    metaStatementHandler.setValue("delegate.boundSql.sql", buildPageSql(sql, page));
                    // Reset RowBounds so MyBatis applies no extra offset/limit on top of the SQL limit.
                    metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
                    metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
                    Connection connection = (Connection) invocation.getArgs()[0];
                    // Count against the original, un-limited SQL.
                    setPageParameter(sql, connection, mappedStatement, boundSql, page);
                }
            }
            // Hand control to the next interceptor / the real prepare call.
            return invocation.proceed();
        }

        /**
         * Runs {@code select count(0)} over {@code sql} and stores the total count and page count
         * in {@code page}.
         *
         * <p>The connection is borrowed from the intercepted call and is not closed here. A
         * SQLException is logged and ignored, leaving the page totals unchanged.
         *
         * @param sql             the original (un-paged) SQL
         * @param connection      connection of the intercepted prepare call
         * @param mappedStatement statement being paged
         * @param boundSql        bound SQL of that statement, supplying parameters
         * @param page            page object that receives the totals
         */
        private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement,
                                      BoundSql boundSql, PageParameter page) {
            String countSql = "select count(0) from (" + sql + ") as total";
            try (PreparedStatement countStmt = connection.prepareStatement(countSql)) {
                BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(), countSql,
                        boundSql.getParameterMappings(), boundSql.getParameterObject());
                copyDynamicParameters(boundSql, countBS);
                setParameters(countStmt, mappedStatement, countBS, boundSql.getParameterObject());
                int totalCount = 0;
                try (ResultSet rs = countStmt.executeQuery()) {
                    if (rs.next()) {
                        totalCount = rs.getInt(1);
                    }
                }
                page.setTotalCount(totalCount);
                page.setTotalPage(computeTotalPage(totalCount, page.getPageSize()));
            } catch (SQLException e) {
                logger.error("Ignore this exception", e);
            }
        }

        /**
         * Copies the metaParameters and additionalParameters fields from {@code source} to
         * {@code target}. Without them, parameters produced by dynamic SQL (e.g. foreach items)
         * cannot be bound in the count query. Reflection failures are logged and skipped.
         */
        @SuppressWarnings("unchecked")
        private void copyDynamicParameters(BoundSql source, BoundSql target) {
            Field metaParamsField = ReflectUtil.getFieldByFieldName(source, "metaParameters");
            if (metaParamsField != null) {
                try {
                    MetaObject mo = (MetaObject) ReflectUtil.getValueByFieldName(source, "metaParameters");
                    ReflectUtil.setValueByFieldName(target, "metaParameters", mo);
                } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
                        | IllegalAccessException e) {
                    // Best effort: continue without this field.
                    logger.error("Ignore this exception", e);
                }
            }
            Field additionalField = ReflectUtil.getFieldByFieldName(source, "additionalParameters");
            if (additionalField != null) {
                try {
                    Map<String, Object> map = (Map<String, Object>) ReflectUtil.getValueByFieldName(source, "additionalParameters");
                    ReflectUtil.setValueByFieldName(target, "additionalParameters", map);
                } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
                        | IllegalAccessException e) {
                    // Best effort: continue without this field.
                    logger.error("Ignore this exception", e);
                }
            }
        }

        /**
         * Number of pages needed for {@code totalCount} rows.
         *
         * @return ceil(totalCount / pageSize), or 0 when pageSize is not positive (avoids division by zero)
         */
        private static int computeTotalPage(int totalCount, int pageSize) {
            if (pageSize <= 0) {
                return 0;
            }
            return totalCount / pageSize + ((totalCount % pageSize == 0) ? 0 : 1);
        }

        /**
         * Sets the SQL parameters ({@code ?} placeholders) with MyBatis' default parameter handler.
         *
         * @throws SQLException if binding a parameter fails
         */
        private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
                                   Object parameterObject) throws SQLException {
            ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);
            parameterHandler.setParameters(ps);
        }

        /**
         * Builds the paged SQL. Only MySQL syntax is generated.
         *
         * @return the paged SQL, or {@code sql} unchanged when {@code page} is null
         */
        private String buildPageSql(String sql, PageParameter page) {
            if (page == null) {
                return sql;
            }
            return buildPageSqlForMysql(sql, page).toString();
        }

        /**
         * Appends {@code limit offset,pageSize} to {@code sql}. Pages are 1-based, and the offset
         * is clamped at 0 so a currentPage below 1 does not produce invalid MySQL.
         *
         * @return a builder holding the paged SQL
         */
        public StringBuilder buildPageSqlForMysql(String sql, PageParameter page) {
            int pageSize = page.getPageSize();
            // Widen to long before multiplying to avoid int overflow on large page numbers.
            long beginRow = Math.max(0L, (long) (page.getCurrentPage() - 1) * pageSize);
            StringBuilder pageSql = new StringBuilder(sql.length() + 32);
            pageSql.append(sql).append(" limit ").append(beginRow).append(',').append(pageSize);
            return pageSql;
        }

        /** Wraps only StatementHandler targets; everything else is returned as-is. */
        @Override
        public Object plugin(Object target) {
            if (target instanceof StatementHandler) {
                return Plugin.wrap(target, this);
            } else {
                return target;
            }
        }

        /**
         * Reads the optional plugin properties {@code dialect} and {@code pageSqlId}. Missing or
         * blank values fall back to the defaults on first use.
         */
        @Override
        public void setProperties(Properties properties) {
            if (properties == null) {
                return;
            }
            dialect = properties.getProperty("dialect", dialect);
            pageSqlId = properties.getProperty("pageSqlId", pageSqlId);
        }

    }
    
    这里碰到一个比较有趣的问题:如果sql中使用了foreach参数,拦截后生成的count语句无法注入这些参数。需要加入以下代码才可以(有的资料上只提到重置metaParameters,实际上additionalParameters也需要复制)。
    // Copy the dynamic-SQL parameter state from the original BoundSql onto the count BoundSql
    // (countBS); without it, foreach parameters cannot be bound in the count query.
    // 1) metaParameters
    Field metaParamsField = ReflectUtil.getFieldByFieldName(boundSql, "metaParameters");
    if (metaParamsField != null) {
        try {
            MetaObject mo = (MetaObject) ReflectUtil.getValueByFieldName(boundSql, "metaParameters");
            ReflectUtil.setValueByFieldName(countBS, "metaParameters", mo);
        } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
                | IllegalAccessException e) {
            // Best effort: log and continue without copying this field.
             logger.error("Ignore this exception", e);
        }
    }
    // 2) additionalParameters — also required for foreach parameters, not only metaParameters.
    Field additionalField = ReflectUtil.getFieldByFieldName(boundSql, "additionalParameters");
    if (additionalField != null) {
        try {
            Map<String, Object> map = (Map<String, Object>) ReflectUtil.getValueByFieldName(boundSql, "additionalParameters");
            ReflectUtil.setValueByFieldName(countBS, "additionalParameters", map);
        } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
                | IllegalAccessException e) {
            // Best effort: log and continue without copying this field.
            logger.error("Ignore this exception", e);
        }
    }
    
    读写分离倒是写好了,但是发现增加了mysql一主多从的读写分离后,此分页拦截器直接失效。
    最后分析发现,原因是我们在做主从分离时自定义了SqlSessionFactory,导致此拦截器没有被注入。
    在上面第4步中,DataSourceSqlSessionFactory中注入拦截器即可,具体代码如下
    通过注解引入拦截器类:
    // Also import PageInterceptor so it is registered as a Spring bean (the class itself has no @Component).
    @Import({DataSourceConfiguration.class,PageInterceptor.class})
    
    注入拦截器
    // Paging interceptor bean, passed to SqlSessionFactoryBean.setPlugins(...).
    @Autowired
        private PageInterceptor pageInterceptor;
    
    SqlSessionFactoryBean中设置拦截器
    // Register the paging interceptor; this must happen before sqlSessionFactoryBean.getObject().
    sqlSessionFactoryBean.setPlugins(new Interceptor[]{pageInterceptor});
    
    这里碰到一个坑,就是设置plugins时必须在sqlSessionFactoryBean.getObject()之前。
    SqlSessionFactory在生成的时候就会获取plugins,并设置到Configuration中,如果在之后设置则不会注入。
    可跟踪源码看到:
    // Entry point traced below:
    sqlSessionFactoryBean.getObject()
    
    // Excerpt from MyBatis-Spring SqlSessionFactoryBean (library source, quoted for reference):
    // the factory is built on the first getObject() call and then cached, so plugins set after
    // that first call are never applied.
    public SqlSessionFactory getObject() throws Exception {
        if (this.sqlSessionFactory == null) {
          afterPropertiesSet();
        }
    
        return this.sqlSessionFactory;
    }
    
    // Excerpt (library source): validates required properties, then builds the SqlSessionFactory.
    public void afterPropertiesSet() throws Exception {
        notNull(dataSource, "Property 'dataSource' is required");
        notNull(sqlSessionFactoryBuilder, "Property 'sqlSessionFactoryBuilder' is required");
        // 'configuration' and 'configLocation' must not both be set.
        state((configuration == null && configLocation == null) || !(configuration != null && configLocation != null),
                  "Property 'configuration' and 'configLocation' can not specified with together");
    
        this.sqlSessionFactory = buildSqlSessionFactory();
      }
    
    // Excerpt from buildSqlSessionFactory() (library source): each configured plugin is added to
    // the MyBatis Configuration here, at build time.
    buildSqlSessionFactory()
    if (!isEmpty(this.plugins)) {
          for (Interceptor plugin : this.plugins) {
            configuration.addInterceptor(plugin);
            if (LOGGER.isDebugEnabled()) {
              LOGGER.debug("Registered plugin: '" + plugin + "'");
            }
          }
        }
    
    最后贴上正确的配置代码(DataSourceSqlSessionFactory代码片段)
    /**
     * Builds the MyBatis SqlSessionFactory over the routing data source, with the paging
     * interceptor registered.
     *
     * Plugins are set before getObject(), because the factory is built on that first call and
     * plugins are only copied into the Configuration at build time.
     *
     * @return the built factory, with underscore-to-camel-case mapping enabled
     * @throws Exception if mapper resources cannot be resolved or the factory fails to build
     */
    @Bean
    @ConditionalOnMissingBean
    public SqlSessionFactory sqlSessionFactory() throws Exception {
        logger.info("======================= init sqlSessionFactory");
        SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
        factoryBean.setPlugins(new Interceptor[]{pageInterceptor});
        factoryBean.setDataSource(roundRobinDataSourceProxy());
        factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(mapperLocations));
        factoryBean.setTypeAliasesPackage(aliasesPackage);
        SqlSessionFactory factory = factoryBean.getObject();
        factory.getConfiguration().setMapUnderscoreToCamelCase(true);
        return factory;
    }
    

    相关文章

      网友评论

        本文标题:springboot配置多数据源后mybatis拦截器失效

        本文链接:https://www.haomeiwen.com/subject/gzsgoxtx.html