美文网首页Java
mybatis自定义拦截器-数据权限过滤

mybatis自定义拦截器-数据权限过滤

作者: 无心火 | 来源:发表于2019-10-20 14:08 被阅读0次

       最近一段时间公司搞新项目,数据库orm选用了mybatis框架。使用一段时间mybaits后感觉比其他orm框架灵活好用,好处就不说了,网上一搜大把。本次主要讲下mybatis自定义拦截器功能的开发,通过拦截器可以解决项目中蛮多的问题​,虽然很多功能不用拦截器也可以实现,但使用自定义拦截器实现功能从我角度至少以下优点(1)灵活,解耦(2)统一控制 ,减少开发工作量,不用散落到每个业务功能点去实现。

        一般业务系统项目都涉及到数据权限的控制,此次结合本项目记录下基于mybatis拦截器实现数据权限的过滤,因为项目用到mybatis-plus的分页插件,数据权限拦截过滤的时机也要控制好,在分页拦截器之前先拦截修改sql,不然会导致查询出来的数据同分页统计出来数量不一致。


    拦截器基本知识


        Mybatis采用责任链模式,通过动态代理组织多个拦截器,通过这些拦截器可以改变mybatis的默认行为,编写自定义拦截器最好了解下它的原理,以便写出安全高效的插件。

     (1)拦截器均需要实现org.apache.ibatis.plugin.Interceptor 接口,对于自定义拦截器必须使用mybatis 提供的注解来指明我们要拦截的是四类中的哪一个类接口。

    具体规则如下:

     a:Intercepts 标识我的类是一个拦截器

     b:Signature 则是指明我们的拦截器需要拦截哪一个接口的哪一个方法;type对应四类接口中的某一个,比如是 Executor;method对应接口中的哪类方法,比如 Executor 的 update 方法;args 对应接口中的哪一个方法,比如 Executor 中 query 因为重载原因,方法有多个,args 就是指明参数类型,从而确定是哪一个方法。

    @Intercepts({

        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})

    })

    (2) mybatis 拦截器默认可拦截的类型四种,即四种接口类型 Executor、StatementHandler、ParameterHandler 和 ResultSetHandler,对于我们的自定义拦截器必须使用 mybatis 提供的注解来指明我们要拦截的是四类中的哪一个类接口。

    (3)拦截器顺序:

     不同类型拦截器的顺序Executor -> ParameterHandler -> StatementHandler ->ResultSetHandler

      同类型的拦截器的不同对象拦截顺序则根据 mybatis 核心配置文件的配置位置,拦截顺序是 从上往下,在mybatis 核心配置文件中需要配置我们的 plugin 


    数据权限过滤


       1.实现业务需求的数据过滤,在用户访问数据库时进行权限判断并改造sql,达到限制低权限用户访问数据的目的

       2.采用技术:mybatis拦截器,java自定义注解,反射,开源jsqlparser

       3.核心业务流程图

    4.代码实现

    (1)创建自定义注解

    ```

    import java.lang.annotation.Documented;

    import java.lang.annotation.ElementType;

    import java.lang.annotation.Inherited;

    import java.lang.annotation.Retention;

    import java.lang.annotation.RetentionPolicy;

    import java.lang.annotation.Target;

    /**

    * 数据权限注解

    *

    */

    @Documented

    @Target( value = { ElementType.TYPE, ElementType.METHOD } )

    @Retention( RetentionPolicy.RUNTIME )

    @Inherited

    public @interface DataAuth

    {

    /**

    * 追加sql的方法名

    * @return

    */

    public String method() default "whereSql";

    /**

    * 表别名

    * @return

    */

    public String tableAlias() default "";

    }

    ```

    (2)mapper方法增加权限注解

    ```

    import com.baomidou.mybatisplus.core.mapper.BaseMapper;

    import com.baomidou.mybatisplus.extension.plugins.pagination.Page;

    import org.apache.ibatis.annotations.Param;

    import java.util.List;

    public interface TestMapper extends BaseMapper<Test> {

        /**

        * 增加权限注解

        *

        */

        @DataAuth(tableAlias = "o")

        List<TestEntity> listData(TestQuery testQuery);

    }

    ```

    (3)创建自定义拦截器

    ```

    import net.sf.jsqlparser.expression.Expression;

    import net.sf.jsqlparser.expression.Parenthesis;

    import net.sf.jsqlparser.expression.operators.conditional.AndExpression;

    import net.sf.jsqlparser.parser.CCJSqlParserManager;

    import net.sf.jsqlparser.parser.CCJSqlParserUtil;

    import net.sf.jsqlparser.statement.select.PlainSelect;

    import net.sf.jsqlparser.statement.select.Select;

    import org.apache.commons.lang3.StringUtils;

    import org.apache.ibatis.executor.Executor;

    import org.apache.ibatis.mapping.BoundSql;

    import org.apache.ibatis.mapping.MappedStatement;

    import org.apache.ibatis.mapping.SqlCommandType;

    import org.apache.ibatis.mapping.SqlSource;

    import org.apache.ibatis.plugin.Interceptor;

    import org.apache.ibatis.plugin.Intercepts;

    import org.apache.ibatis.plugin.Invocation;

    import org.apache.ibatis.plugin.Plugin;

    import org.apache.ibatis.plugin.Signature;

    import org.apache.ibatis.reflection.DefaultReflectorFactory;

    import org.apache.ibatis.reflection.MetaObject;

    import org.apache.ibatis.reflection.factory.DefaultObjectFactory;

    import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;

    import org.apache.ibatis.session.ResultHandler;

    import org.apache.ibatis.session.RowBounds;

    import org.slf4j.Logger;

    import org.slf4j.LoggerFactory;

    import org.springframework.beans.BeansException;

    import org.springframework.context.ApplicationContext;

    import org.springframework.context.ApplicationContextAware;

    import org.springframework.stereotype.Component;

    import java.io.StringReader;

    import java.lang.reflect.Method;

    import java.util.Map;

    import java.util.Properties;

    /**

    * 数据权限拦截器

    * 根据各个微服务,继承DataAuthService增加不同的where语句

    *

    */

    @Component

    @Intercepts({@Signature(method = "query",type = Executor.class,args =  {

            MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}

        )

    })

    public class MybatisDataAuthInterceptor implements Interceptor,

        ApplicationContextAware {

        private final static Logger logger = LoggerFactory.getLogger(MybatisDataAuthInterceptor.class);

        private static ApplicationContext context;

        @Override

        public void setApplicationContext(ApplicationContext applicationContext)

            throws BeansException {

            context = applicationContext;

        }

        @Override

        public Object intercept(Invocation arg0) throws Throwable {

            MappedStatement mappedStatement = (MappedStatement) arg0.getArgs()[0];

            // 只对查询sql拦截

            if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {

                return arg0.proceed();

            }

            // String mSql = sql;

            // 注解逻辑判断 添加注解了才拦截追加

            Class<?> classType = Class.forName(mappedStatement.getId()

                                                              .substring(0,

                        mappedStatement.getId().lastIndexOf(".")));

            String mName = mappedStatement.getId()

                                          .substring(mappedStatement.getId()

                                                                    .lastIndexOf(".") +

                    1, mappedStatement.getId().length()); //

            for (Method method : classType.getDeclaredMethods()) {

                if (method.isAnnotationPresent(DataAuth.class) &&

                        mName.equals(method.getName())) {

                    /**

                    * 查找标识了该注解 的实现 类

                    */

                    Map<String, Object> beanMap = context.getBeansWithAnnotation(DataAuth.class);

                    if ((beanMap != null) && (beanMap.entrySet().size() > 0)) {

                        for (Map.Entry<String, Object> entry : beanMap.entrySet()) {

                            DataAuth action = method.getAnnotation(DataAuth.class);

                            if (StringUtils.isEmpty(action.method())) {

                                break;

                            }

                            try {

                                Method md = entry.getValue().getClass()

                                                .getMethod(action.method(),

                                        new Class[] { String.class });

                                /**

                                * 反射获取业务 sql

                                */

                                String whereSql = (String) md.invoke(context.getBean(

                                            entry.getValue().getClass()),

                                        new Object[] { action.tableAlias() });

                                if (!StringUtils.isEmpty(whereSql) &&

                                        !"null".equalsIgnoreCase(whereSql)) {

                                    Object parameter = null;

                                    if (arg0.getArgs().length > 1) {

                                        parameter = arg0.getArgs()[1];

                                    }

                                    BoundSql boundSql = mappedStatement.getBoundSql(parameter);

                                    MappedStatement newStatement = newMappedStatement(mappedStatement,

                                            new BoundSqlSqlSource(boundSql));

                                    MetaObject msObject = MetaObject.forObject(newStatement,

                                            new DefaultObjectFactory(),

                                            new DefaultObjectWrapperFactory(),

                                            new DefaultReflectorFactor());

                                    /**

                                    * 通过JSqlParser解析 原有sql,追加sql条件

                                    */

                                    CCJSqlParserManager parserManager = new CCJSqlParserManager();

                                    Select select = (Select) parserManager.parse(new StringReader(

                                                boundSql.getSql()));

                                    PlainSelect selectBody = (PlainSelect) select.getSelectBody();

                                    Expression whereExpression = CCJSqlParserUtil.parseCondExpression(whereSql);

                                    selectBody.setWhere(new AndExpression(

                                            selectBody.getWhere(),

                                            new Parenthesis(whereExpression)));

                                    /**

                                    * 修改sql

                                    */

                                    msObject.setValue("sqlSource.boundSql.sql",

                                        selectBody.toString());

                                    arg0.getArgs()[0] = newStatement;

                                    logger.info("Interceptor sql:" +

                                        selectBody.toString());

                                }

                            } catch (Exception e) {

                                logger.error(null, e);

                            }

                            break;

                        }

                    }

                }

            }

            return arg0.proceed();

        }

        private MappedStatement newMappedStatement(MappedStatement ms,

            SqlSource newSqlSource) {

            MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),

                    ms.getId(), newSqlSource, ms.getSqlCommandType());

            builder.resource(ms.getResource());

            builder.fetchSize(ms.getFetchSize());

            builder.statementType(ms.getStatementType());

            builder.keyGenerator(ms.getKeyGenerator());

            if ((ms.getKeyProperties() != null) &&

                    (ms.getKeyProperties().length != 0)) {

                StringBuilder keyProperties = new StringBuilder();

                for (String keyProperty : ms.getKeyProperties()) {

                    keyProperties.append(keyProperty).append(",");

                }

                keyProperties.delete(keyProperties.length() - 1,

                    keyProperties.length());

                builder.keyProperty(keyProperties.toString());

            }

            builder.timeout(ms.getTimeout());

            builder.parameterMap(ms.getParameterMap());

            builder.resultMaps(ms.getResultMaps());

            builder.resultSetType(ms.getResultSetType());

            builder.cache(ms.getCache());

            builder.flushCacheRequired(ms.isFlushCacheRequired());

            builder.useCache(ms.isUseCache());

            return builder.build();

        }

        /**

        * 当目标类是Executor类型时,才包装目标类,否者直接返回目标本身,减少目标被代理的次数

        */

        @Override

        public Object plugin(Object target) {

            if (target instanceof Executor) {

                return Plugin.wrap(target, this);

            }

            return target;

        }

        @Override

        public void setProperties(Properties arg0) {

            // TODO Auto-generated method stub

        }

        class BoundSqlSqlSource implements SqlSource {

            private BoundSql boundSql;

            public BoundSqlSqlSource(BoundSql boundSql) {

                this.boundSql = boundSql;

            }

            @Override

            public BoundSql getBoundSql(Object parameterObject) {

                return boundSql;

            }

        }

    }

    ```

    (4)增加业务逻辑

    ```

    import com.baomidou.mybatisplus.core.toolkit.StringUtils;

    import com.winhong.wincloud.constant.RoleTypeJudge;

    import com.winhong.wincore.async.ThreadLocalHolder;

    import com.winhong.wincore.user.LoginUserHolder;

    import com.winhong.wincore.user.UserInfo;

    import org.slf4j.Logger;

    import org.slf4j.LoggerFactory;

    import org.springframework.stereotype.Service;

    @Service

    public abstract class AbstractDataAuthService {

        private static final Logger LOG = LoggerFactory.getLogger(AbstractDataAuthService.class);

        /**

        * 默认查询sql,根据角色不同追加不同业务查询条件

        *

        * @return

        */

        public String whereSql(String tableAlias) {

            if (!StringUtils.isEmpty(tableAlias)) {

                tableAlias = tableAlias + ".";

            }

            StringBuffer sql = new StringBuffer();

            //利用threadlocal获取用户角色信息

            UserInfo userInfo = LoginUserHolder.getUser();

            // 普通 用户

            if (RoleTypeJudge.isNormalUser(userInfo.getRoleTypeCode())) {

                sql.append(nomalUserSql(userInfo.getUserUuid(), tableAlias));

            }

            // 管理员

            else if (RoleTypeJudge.isManager(userInfo.getRoleTypeCode())) {

                sql.append(managerSql(tableAlias));

            } else {

            }

            return sql.toString();

        }

    }

    ```

    相关文章

      网友评论

        本文标题:mybatis自定义拦截器-数据权限过滤

        本文链接:https://www.haomeiwen.com/subject/vktamctx.html