核心思想
通过 MyBatis 拦截器重写 SQL：为标注了 @DataPermissions 的 Mapper 方法追加数据权限过滤条件。
拦截器部分代码
/**
 * MyBatis plugin that enforces data permissions by rewriting SQL before it is prepared.
 *
 * <p>Only statements whose mapper method carries {@link DataPermissions} are rewritten; every
 * other statement passes through unchanged.
 */
@Component
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Slf4j
public class PermissionsInterceptor implements Interceptor {

    /** Resolves the permitted column values for a user / table / column / operator. */
    @Autowired
    private DataAuthHandlerContext dataAuthHandler;

    /**
     * Appends permission conditions to the SQL of {@link DataPermissions}-annotated statements.
     *
     * <p>Each annotation value has the form {@code table:column[:operator]}. An absent or empty
     * operator defaults to {@link PermissionCondition#CONDITION_IN}. A user code equal to
     * {@code NotifyConstant.IS_MANAGER} bypasses filtering.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of proceeding with the (possibly rewritten) statement
     * @throws BaseException when the user has no permitted values for a configured column
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
        // The intercepted handler is a RoutingStatementHandler whose "delegate" (a BaseStatementHandler
        // subclass) holds the MappedStatement. Its id is the mapper method's full name,
        // e.g. com.uv.dao.UserMapper.insertUser.
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        String[] permissionsValue = findPermissions(mappedStatement.getId());
        if (permissionsValue == null || permissionsValue.length == 0) {
            return invocation.proceed();
        }

        BoundSql boundSql = statementHandler.getBoundSql();
        String userCode = getUserCode(boundSql);
        // Managers are not subject to data-permission checks.
        if (NotifyConstant.IS_MANAGER.equals(userCode)) {
            return invocation.proceed();
        }
        List<PermissionCondition> permissionConditions = buildConditions(userCode, permissionsValue);
        // Generate the permission-enhanced SQL.
        // Future work: make DataPermissionsSqlParseUtil pluggable.
        String newSql = DataPermissionsSqlParseUtil.convert(boundSql.getSql(), permissionConditions);
        if (StringUtils.isNotBlank(newSql)) {
            // Overwrite BoundSql's private "sql" field reflectively.
            Field field = boundSql.getClass().getDeclaredField("sql");
            ReflectionUtils.makeAccessible(field);
            field.set(boundSql, newSql);
        }
        return invocation.proceed();
    }

    /**
     * Looks up the {@link DataPermissions} values declared on the mapper method behind a statement id.
     *
     * @param statementId MyBatis statement id, {@code <namespace>.<method>}
     * @return the annotation values of the last matching method, or {@code null} when no method
     *         matches or the namespace is not a loadable class
     */
    private String[] findPermissions(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        Class<?> mapperType;
        try {
            mapperType = Class.forName(statementId.substring(0, lastDot));
        } catch (ClassNotFoundException e) {
            // A namespace that is not a Java type cannot carry annotations; previously this
            // exception aborted every statement of such a namespace.
            log.debug("PermissionsInterceptor: namespace of {} is not a class, skipping", statementId);
            return null;
        }
        String methodName = statementId.substring(lastDot + 1);
        String[] permissionsValue = null;
        for (Method method : mapperType.getDeclaredMethods()) {
            // "<method>_COUNT" also matches, presumably the count statement generated for paged
            // queries (e.g. by PageHelper) — verify against the paging plugin in use.
            if (method.isAnnotationPresent(DataPermissions.class)
                    && (methodName.equals(method.getName()) || methodName.equals(method.getName() + "_COUNT"))) {
                permissionsValue = method.getAnnotation(DataPermissions.class).value();
            }
        }
        return permissionsValue;
    }

    /**
     * Turns {@code table:column[:operator]} specs into conditions holding the user's permitted values.
     * Specs with fewer than two segments or an empty table/column are skipped.
     *
     * @param userCode         the current user's code (may be {@code null})
     * @param permissionsValue the {@link DataPermissions} values
     * @return one condition per usable spec
     * @throws BaseException when the user has no permitted values for a spec
     */
    private List<PermissionCondition> buildConditions(String userCode, String[] permissionsValue) throws BaseException {
        List<PermissionCondition> permissionConditions = new ArrayList<>();
        for (String permissions : permissionsValue) {
            String[] strs = permissions.split(":");
            if (strs.length < 2) {
                continue;
            }
            String tableName = strs[0];
            String columnName = strs[1];
            if (StrUtil.isEmpty(tableName) || StrUtil.isEmpty(columnName)) {
                continue;
            }
            // The operator segment is optional; reading strs[2] unconditionally threw
            // ArrayIndexOutOfBoundsException for "table:column" specs.
            String operate = strs.length > 2 ? strs[2] : null;
            String realOperate = StrUtil.isNotEmpty(operate) ? operate : PermissionCondition.CONDITION_IN;
            List<String> values = dataAuthHandler.getDbColumnPermissions(userCode, tableName, columnName, realOperate);
            if (CollectionUtils.isEmpty(values)) {
                // No permission data for this column: refuse to run the statement.
                throw new BaseException("90037", "用户::" + userCode + "无表::"
                        + tableName + ":字段:" + columnName + "::" + realOperate + "::操作权限");
            }
            PermissionCondition condition = new PermissionCondition();
            condition.setTableName(tableName);
            condition.setColumnName(columnName);
            condition.setValues(values);
            condition.setOperate(realOperate);
            permissionConditions.add(condition);
        }
        return permissionConditions;
    }

    /**
     * Extracts the {@code userCode} entry from the statement's parameter object.
     *
     * @param boundSql the statement being prepared
     * @return the user code, or {@code null} when the parameter object is not a map or has no
     *         {@code userCode} entry
     */
    private String getUserCode(BoundSql boundSql) {
        final Object parameterObject = boundSql.getParameterObject();
        // Only map-style parameters can carry "userCode"; a single POJO/scalar parameter used to
        // fail with ClassCastException.
        if (!(parameterObject instanceof java.util.Map)) {
            return null;
        }
        java.util.Map<?, ?> params = (java.util.Map<?, ?>) parameterObject;
        // Check first: MyBatis' ParamMap.get throws BindingException for unknown keys.
        if (!params.containsKey("userCode")) {
            return null;
        }
        Object userObj = params.get("userCode");
        return userObj == null ? null : userObj.toString();
    }

    /** Wraps only {@link StatementHandler} targets; every other target is returned unchanged. */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /** This plugin takes no configuration properties. */
    @Override
    public void setProperties(Properties properties) {
    }
}
权限处理上下文（DataAuthHandlerContext）代码
@Component
@Slf4j
public class DataAuthHandlerContext {
private final Map<String, DataAuthHandler> methodMap = new ConcurrentHashMap<>();
@Autowired
public DataAuthHandlerContext(Map<String, DataAuthHandler> map) {
this.methodMap.clear();
map.forEach((k, v) -> this.methodMap.put(v.getTableName(), v));
}
/**
* @param operator 操作用户
* @param tableName 数据库表名称
* @param dbColumn 列名称
* @param conditionType 操作类型 参考PermissionCondition.operate
* @return {@link List<String>}
* @throws
* @desc 后续优化方向
* (1)方法很多的时候 反射性能可能会很慢 可以缓存一段时间
* 可以参考 guava的 com.google.common.eventbus.SubscriberRegistry#flattenHierarchyCache
* (2)可以在编译期就检查是否有对应的DbColumn 方法
* @date 2021-3-26 15:57
*/
public List<String> getDbColumnPermissions(String operator, String tableName, String dbColumn, String conditionType) {
if (StrUtil.isNotEmpty(operator) && StrUtil.isNotEmpty(tableName)
&& StrUtil.isNotEmpty(dbColumn) && StrUtil.isNotEmpty(conditionType)) {
DataAuthHandler DataAuthHandler = methodMap.get(tableName);
Method[] declaredMethods = DataAuthHandler.getClass().getDeclaredMethods();
for (Method declaredMethod : declaredMethods) {
DbColumn annotation = declaredMethod.getAnnotation(DbColumn.class);
String name = (annotation == null) ? null : annotation.name();
String annoConditionType = (annotation == null) ? null : annotation.conditionType();
if (StrUtil.isNotEmpty(name) && name.equals(dbColumn) && annoConditionType.equals(conditionType)) {
try {
List<String> result = (List<String>) declaredMethod.invoke(DataAuthHandler, operator);
return result;
} catch (Exception e) {
log.error("DataAuthHandlerContext.getDbColumnPermissions,执行异常,{}", e);
}
}
}
}
return new ArrayList<>();
}
}
自定义注解
/**
 * Marks a data-permission handler method as the source of permitted values for one database
 * column under one condition type. It is retained at runtime so it can be discovered reflectively.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(value = ElementType.METHOD)
public @interface DbColumn {

    /**
     * Column name; must match the name used in the database.
     *
     * @return the column name
     */
    String name();

    /**
     * Condition type served by the annotated method (see PermissionCondition's operate values).
     *
     * @return the condition type
     */
    String conditionType();
}
数据权限处理接口
/**
 * Data-permission handler contract.
 *
 * <p>Each implementation serves one database table. Within an implementation, every pair of
 * database column and condition type is served by one method. That method is marked with
 * {@code @DbColumn} and invoked reflectively.
 */
public interface DataAuthHandler {

    /**
     * Returns the name of the database table this handler serves.
     *
     * @return the table name
     */
    String getTableName();
}
辅助类
/**
 * One data-permission condition, {@code [alias.]column <operate> value(s)}, to append to a query.
 *
 * <p>Future work:
 * (1) support more condition types;
 * (2) values are always rendered as string literals, which may not match the column type and can
 *     defeat indexes through implicit conversion;
 * (3) package the whole plugin as a Spring Boot starter.
 *
 * @author YEJUNMENG
 */
@Data
public class PermissionCondition {
    public static final String CONDITION_EQUALS = " = ";
    public static final String CONDITION_LESS = " < ";
    public static final String CONDITION_GREATER = " > ";
    public static final String CONDITION_LESS_OR_EQUALS = " <= ";
    public static final String CONDITION_GREATER_OR_EQUALS = " >= ";
    public static final String CONDITION_LIKE = " LIKE ";
    /**
     * MySQL does not support START_WITH syntax; {@link #CONDITION_MYSQL_START_WITH} renders a LIKE
     * prefix match instead.
     */
    public static final String CONDITION_START_WITH = " START_WITH ";
    /** Rendered as {@code column like 'value%'}. */
    public static final String CONDITION_MYSQL_START_WITH = "MYSQL LIKE ";
    public static final String CONDITION_IN = " in ";

    private String tableName;
    private String columnName;
    /** One of the {@code CONDITION_*} constants. */
    private String operate;
    private List<String> values;

    /**
     * Renders this condition as an SQL predicate.
     *
     * @param alias table alias used to qualify the column, or {@code null} for an unqualified column
     * @return the predicate's characters. The array is empty for an unrecognised operator, or when a
     *         non-IN operator has no values.
     */
    public char[] toWhere(Object alias) {
        final String operate = getOperate();
        final String columnName = getColumnName();
        String result = "";
        if (CONDITION_IN.equals(operate)) {
            result = buildInConditionSql(alias, operate, columnName);
        } else if (CONDITION_LIKE.equals(operate)) {
            result = buildLikeSql(alias, operate, columnName);
        } else if (CONDITION_EQUALS.equals(operate)
                || CONDITION_LESS_OR_EQUALS.equals(operate)
                || CONDITION_GREATER_OR_EQUALS.equals(operate)
                || CONDITION_LESS.equals(operate)
                || CONDITION_GREATER.equals(operate)
                || CONDITION_START_WITH.equals(operate)) {
            result = buildCommonSql(alias, operate, columnName);
        } else if (CONDITION_MYSQL_START_WITH.equals(operate)) {
            result = buildMysqlStartWithSql(alias, operate, columnName);
        }
        return result.toCharArray();
    }

    /** Renders {@code [alias.]column <operate> 'firstValue'}; empty when there are no values. */
    private String buildCommonSql(Object alias, String operate, String columnName) {
        final List<String> values = getValues();
        if (values == null || values.isEmpty()) {
            return "";
        }
        return qualify(alias, columnName) + operate + quote(values.get(0));
    }

    /** Renders {@code [alias.]column <operate> '%firstValue%'}; empty when there are no values. */
    private String buildLikeSql(Object alias, String operate, String columnName) {
        final List<String> values = getValues();
        if (values == null || values.isEmpty()) {
            return "";
        }
        // NOTE(review): '%' and '_' inside the value are not escaped and act as LIKE wildcards.
        return qualify(alias, columnName) + operate + quote("%" + values.get(0) + "%");
    }

    /** Renders {@code [alias.]column like 'firstValue%'}; empty when there are no values. */
    private String buildMysqlStartWithSql(Object alias, String operate, String columnName) {
        final List<String> values = getValues();
        if (values == null || values.isEmpty()) {
            return "";
        }
        return qualify(alias, columnName) + " like " + quote(values.get(0) + "%");
    }

    /**
     * Renders {@code [alias.]column <operate> ('v1','v2',...)}.
     * An empty or missing value list renders {@code 1 = 0}, so the query matches no rows; the old
     * output {@code in ( )} was invalid SQL, and a null list caused a NullPointerException.
     */
    private String buildInConditionSql(Object alias, String operate, String columnName) {
        final List<String> values = getValues();
        if (values == null || values.isEmpty()) {
            return "1 = 0";
        }
        // Single quotes, consistent with the other builders: double quotes denote identifiers
        // in standard SQL (and in MySQL's ANSI_QUOTES mode).
        StringBuilder sb = new StringBuilder(qualify(alias, columnName)).append(operate).append('(');
        for (int i = 0; i < values.size(); i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(quote(values.get(i)));
        }
        return sb.append(')').toString();
    }

    /** Renders {@code alias.column}, or just {@code column} when alias is {@code null}. */
    private static String qualify(Object alias, String columnName) {
        return alias == null ? columnName : alias + "." + columnName;
    }

    /**
     * Renders a single-quoted SQL string literal, doubling embedded single quotes (standard SQL).
     * NOTE(review): MySQL's default mode also treats backslash as an escape character, so values
     * containing backslashes are not fully neutralised; bind parameters would be safer.
     */
    private static String quote(String value) {
        return "'" + String.valueOf(value).replace("'", "''") + "'";
    }
}
SQL 重写
DataPermissionsSqlParseUtil
/**
* @author YEJUNMENG
* 权限控制SQL处理类
*/
@Slf4j
public class DataPermissionsSqlParseUtil {
private static Pattern paramPattern = Pattern.compile("\\$\\d*");
private static String LIMIT_OFFSET_STR = "LIMIT ? OFFSET ?";
private static String LIMIT_STR = "LIMIT ?, ? ";
private static final String[] PARSE_KEY_WORD = {"year"};
public static String convert(String sql, List<PermissionCondition> permissionConditions) {
try {
sql = before(sql);
log.info("ORI SQL :{}", sql);
//支持多表关联查询
Map<String, String> aliasMap = select_table_alias(sql);
for (PermissionCondition permissionCondition : permissionConditions) {
String alias = aliasMap.get(permissionCondition.getTableName());
char[] chars = permissionCondition.toWhere(alias);
sql = SQLUtils.addCondition(sql, new String(chars), null);
}
String newSql = sql;
Matcher matcher = paramPattern.matcher(newSql);
StringBuffer retStr = new StringBuffer();
while (matcher.find()) {
matcher.appendReplacement(retStr, "?");
}
matcher.appendTail(retStr);
newSql = retStr.toString();
if (newSql.indexOf(LIMIT_OFFSET_STR) > 0) {
newSql = newSql.replace(LIMIT_OFFSET_STR, LIMIT_STR);
}
if (newSql.indexOf("(INTERVAL") >= 0) {
newSql = newSql.replaceAll("(\\()(INTERVAL \\d MONTH)(\\))", "$2");
}
if (newSql.indexOf("CAST(1 AS INTERVAL YEAR)") >= 0) {
newSql = newSql.replaceAll("\\(CAST\\(1 AS INTERVAL YEAR\\)\\)", "INTERVAL 1 YEAR");
}
newSql = after(newSql);
log.info("NEW SQL :{}", newSql);
return newSql;
} catch (Exception e) {
sql = after(sql);
e.printStackTrace();
}
return sql;
}
public static String before(String sql) {
if (sql.indexOf("INTERVAL 1 YEAR") < 0) {
for (String kw : PARSE_KEY_WORD) {
int index = sql.toUpperCase().indexOf(kw.toUpperCase());
while (index >= 0) {
String[] ks = kw.split("");
String w = StringUtils.join(ks, "_");
sql = sql.substring(0, index) + w + sql.substring(index + kw.length());
index = sql.toUpperCase().indexOf(kw.toUpperCase());
}
}
}
return sql;
}
public static String after(String sql) {
for (String kw : PARSE_KEY_WORD) {
String[] ks = kw.split("");
String w = StringUtils.join(ks, "_");
int index = sql.toUpperCase().indexOf(w.toUpperCase());
while (index >= 0) {
sql = sql.substring(0, index) + kw + sql.substring(index + w.length());
index = sql.toUpperCase().indexOf(w.toUpperCase());
}
}
return sql;
}
/**
* @Description: 查询sql字段
* @Param: [sql]
* @Return: java.util.List<java.lang.String>
**/
public static List<String> select_items(String sql)
throws JSQLParserException {
CCJSqlParserManager parserManager = new CCJSqlParserManager();
Select select = (Select) parserManager.parse(new StringReader(sql));
PlainSelect plain = (PlainSelect) select.getSelectBody();
List<SelectItem> selectitems = plain.getSelectItems();
List<String> str_items = new ArrayList<String>();
if (selectitems != null) {
for (SelectItem selectitem : selectitems) {
str_items.add(selectitem.toString());
}
}
return str_items;
}
public static Map<String, String> select_table_alias(String sql) {
Map<String, String> map = new HashMap<>();
try {
Select select = (Select) CCJSqlParserUtil.parse(sql);
SelectBody selectBody = select.getSelectBody();
PlainSelect plainSelect = (PlainSelect) selectBody;
Table table = (Table) plainSelect.getFromItem();
if (table.getAlias() != null) {
map.put(table.getName(), table.getAlias().getName());
}
if (plainSelect.getJoins() != null) {
for (Join join : plainSelect.getJoins()) {
Table table1 = (Table) join.getRightItem();
if (table1.getAlias() != null) {
map.put(table1.getName(), table1.getAlias().getName());
}
}
}
} catch (Exception e) {
e.printStackTrace();
}
return map;
}
/**
* @Description: 查询表名table
* @Param: [sql]
* @Return: java.util.List<java.lang.String>
**/
public static List<String> select_table(String sql)
throws JSQLParserException {
Statement statement = CCJSqlParserUtil.parse(sql);
Select selectStatement = (Select) statement;
TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
return tablesNamesFinder.getTableList(selectStatement);
}
/**
* @Description: 查询join
* @Param: [sql]
* @Return: java.util.List<java.lang.String>
**/
public static List<String> select_join(String sql)
throws JSQLParserException {
Statement statement = CCJSqlParserUtil.parse(sql);
Select selectStatement = (Select) statement;
PlainSelect plain = (PlainSelect) selectStatement.getSelectBody();
List<Join> joinList = plain.getJoins();
List<String> tablewithjoin = new ArrayList<String>();
if (joinList != null) {
for (Join join : joinList) {
join.setLeft(true);//是否开放left jion中的left
tablewithjoin.add(join.toString());
}
}
return tablewithjoin;
}
/**
* @Description: 查询where
* @Param: [sql]
* @Return: java.lang.String
**/
public static String select_where(String sql)
throws JSQLParserException {
CCJSqlParserManager parserManager = new CCJSqlParserManager();
Select select = (Select) parserManager.parse(new StringReader(sql));
PlainSelect plain = (PlainSelect) select.getSelectBody();
Expression where_expression = plain.getWhere();
return where_expression.toString();
}
/**
* @Description: 对where条件解析并返回结果
* @Param: [sql, metadata:是否开启原数据]
* @Return: java.util.List<java.lang.Object>
**/
public static List<Map<String, Object>> parseWhere(String sql) {
try {
Select select = (Select) CCJSqlParserUtil.parse(sql);
SelectBody selectBody = select.getSelectBody();
PlainSelect plainSelect = (PlainSelect) selectBody;
Expression expr = CCJSqlParserUtil.parseCondExpression(plainSelect.getWhere().toString());
List<Map<String, Object>> arrList = new ArrayList<>();
expr.accept(new ExpressionDeParser() {
int depth = 0;
@Override
public void visit(Parenthesis parenthesis) {
depth++;
parenthesis.getExpression().accept(this);
depth--;
}
@Override
public void visit(OrExpression orExpression) {
visitBinaryExpr(orExpression, "OR");
}
@Override
public void visit(AndExpression andExpression) {
visitBinaryExpr(andExpression, "AND");
}
private void visitBinaryExpr(BinaryExpression expr, String operator) {
Map<String, Object> map = new HashMap<>();
if (!(expr.getLeftExpression() instanceof OrExpression)
&& !(expr.getLeftExpression() instanceof AndExpression)
&& !(expr.getLeftExpression() instanceof Parenthesis)) {
getBuffer();
}
expr.getLeftExpression().accept(this);
map.put("leftExpression", expr.getLeftExpression());
map.put("operator", operator);
if (!(expr.getRightExpression() instanceof OrExpression)
&& !(expr.getRightExpression() instanceof AndExpression)
&& !(expr.getRightExpression() instanceof Parenthesis)) {
getBuffer();
}
expr.getRightExpression().accept(this);
map.put("rightExpression", expr.getRightExpression());
arrList.add(map);
}
});
return arrList;
} catch (JSQLParserException e) {
e.printStackTrace();
}
return null;
}
/**
* @Description: 完全解析where单个条件并返回
* @Param: [where]
* @Return: java.util.Map<java.lang.Object,java.lang.Object>
**/
public static Map<Object, Object> fullResolutionWhere(String where) {
Map<Object, Object> map = new HashMap<>();
try {
Expression expr = CCJSqlParserUtil.parseCondExpression(where);
expr.accept(new ExpressionVisitorAdapter() {
@Override
protected void visitBinaryExpression(BinaryExpression expr) {
if (expr instanceof ComparisonOperator) {
map.put("leftExpression", expr.getLeftExpression());
map.put("operate", expr.getStringExpression());
map.put("rightExpression", expr.getRightExpression());
}
super.visitBinaryExpression(expr);
}
});
//暂时无法解析IS NOT NULL 和 IS NULL
if (CollectionUtils.isEmpty(map) && (where.toUpperCase().contains("IS NOT NULL") || where.toUpperCase().contains("IS NULL"))) {
map.put("leftExpression", where.substring(0, where.lastIndexOf("IS")));
map.put("operate", null);
map.put("rightExpression", where.substring(where.lastIndexOf("IS"), where.length()));
}
} catch (Exception e) {
e.printStackTrace();
}
return map;
}
/**
* @Description: 查询 group by
* @Param: [sql]
* @Return: java.util.List<java.lang.String>
**/
public static List<String> select_groupby(String sql)
throws JSQLParserException {
CCJSqlParserManager parserManager = new CCJSqlParserManager();
Select select = (Select) parserManager.parse(new StringReader(sql));
PlainSelect plain = (PlainSelect) select.getSelectBody();
List<Expression> GroupByColumnReferences = plain.getGroupBy().getGroupByExpressions();
List<String> str_groupby = new ArrayList<String>();
if (GroupByColumnReferences != null) {
for (Expression groupByColumnReference : GroupByColumnReferences) {
str_groupby.add(groupByColumnReference.toString());
}
}
return str_groupby;
}
/**
* @Description: 查询order by
* @Param: [sql]
* @Return: java.util.List<java.lang.String>
**/
public static List<String> select_orderby(String sql)
throws JSQLParserException {
CCJSqlParserManager parserManager = new CCJSqlParserManager();
Select select = (Select) parserManager.parse(new StringReader(sql));
PlainSelect plain = (PlainSelect) select.getSelectBody();
List<OrderByElement> OrderByElements = plain.getOrderByElements();
List<String> str_orderby = new ArrayList<String>();
if (OrderByElements != null) {
for (OrderByElement orderByElement : OrderByElements) {
str_orderby.add(orderByElement.toString());
}
}
return str_orderby;
}
/**
* @Description: 子查询
* @Param: [selectBody]
* @Return: java.util.Map
* @Demo: select * from (select userid from (select userid from a)a) a
**/
public static Map select_subselect(SelectBody selectBody) throws JSQLParserException {
Map<String, String> map = new HashMap<String, String>();
if (selectBody instanceof PlainSelect) {
List<SelectItem> selectItems = ((PlainSelect) selectBody).getSelectItems();
for (SelectItem selectItem : selectItems) {
if (selectItem.toString().contains("(") && selectItem.toString().contains(")")) {
map.put("selectItemsSubselect", selectItem.toString());
}
}
Expression where = ((PlainSelect) selectBody).getWhere();
if (where != null) {
String whereStr = where.toString();
if (whereStr.contains("(") && whereStr.contains(")")) {
int firstIndex = whereStr.indexOf("(");
int lastIndex = whereStr.lastIndexOf(")");
CharSequence charSequence = whereStr.subSequence(firstIndex, lastIndex + 1);
map.put("whereSubselect", charSequence.toString());
}
}
FromItem fromItem = ((PlainSelect) selectBody).getFromItem();
if (fromItem instanceof SubSelect) {
map.put("fromItemSubselect", fromItem.toString());
}
} else if (selectBody instanceof WithItem) {
select_subselect(((WithItem) selectBody).getSelectBody());
}
return map;
}
/**
* @Description: 判断是否为多级子查询
* @Param: [selectBody]
* @Return: boolean
* @Demo: select * from (select userid from (select userid from a)a) a
**/
public static boolean isMultiSubSelect(SelectBody selectBody) {
if (selectBody instanceof PlainSelect) {
FromItem fromItem = ((PlainSelect) selectBody).getFromItem();
if (fromItem instanceof SubSelect) {
SelectBody subBody = ((SubSelect) fromItem).getSelectBody();
if (subBody instanceof PlainSelect) {
FromItem subFromItem = ((PlainSelect) subBody).getFromItem();
if (subFromItem instanceof SubSelect) {
return true;
}
}
}
}
return false;
}
public static void main(String[] args) {
String sql = "SELECT\n" +
"\tt.id AS notifyId,\n" +
"\tt.title,\n" +
"\tt.created_by,\n" +
"\tt.notice_date,\n" +
"\tp.id AS cardId,\n" +
"\tp.`name` AS cardName,\n" +
"\tn.dept_code AS deptCode,\n" +
"\tn.dept_name AS deptName,\n" +
"\tk.object_id AS objId,\n" +
"\tk.object_name AS objName,\n" +
"\tk.object_type AS objType\n" +
"FROM\n" +
"\tw_notification t\n" +
"INNER JOIN w_notif_card m ON t.id = m.notification_id\n" +
"INNER JOIN w_card p ON m.warn_card_id = p.id\n" +
"INNER JOIN w_notif_dept n ON t.id = n.notification_id\n" +
"INNER JOIN w_notif_object k ON t.id = k.notification_id\n" +
"WHERE\n" +
"\tt.id = 3419854225570924552";
List<PermissionCondition> permissionConditions = new ArrayList<>();
PermissionCondition dept = new PermissionCondition();
dept.setTableName("w_notif_dept");
dept.setColumnName("dept_code");
dept.setOperate(PermissionCondition.CONDITION_EQUALS);
dept.setValues(Arrays.asList("2"));
PermissionCondition notify = new PermissionCondition();
notify.setTableName("w_notification");
notify.setColumnName("created_by");
notify.setOperate(PermissionCondition.CONDITION_MYSQL_START_WITH);
notify.setValues(Arrays.asList("umrunner"));
permissionConditions.add(dept);
permissionConditions.add(notify);
Map<String, String> aliasMap = select_table_alias(sql);
for (PermissionCondition permissionCondition : permissionConditions) {
String alias = aliasMap.get(permissionCondition.getTableName());
char[] chars = permissionCondition.toWhere(alias);
sql = SQLUtils.addCondition(sql, new String(chars), null);
}
System.out.println(sql);
}
}
网友评论