[druid source code analysis] 10 WallFilter analysis


Author: AndyWei123 | Published: 2021-11-20 00:05

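Next, we move on to the packages outside druid's pool package, starting with WallFilter. We begin with a small WallFilter example. The first step is to enable WallFilter in the configuration file, so let's start from the configuration: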

    spring:
      datasource:
        druid:
          filter:
            wall:
              enabled: true
              config:
                select-where-alway-true-check: true
    

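First enable WallFilter, then fill in its config. Here we set select-where-alway-true-check: true, which checks SELECT statements for WHERE conditions that are always true ("alway", without the s, is the spelling druid itself uses for this property). Besides this setting, the following properties can also be configured: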

[Table: configurable WallFilter properties (not preserved)]
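As a rough illustration of what an always-true check looks for, here is a toy sketch. It is not druid's implementation: the Expr, Literal, Column, BinaryOp and AlwaysTrueCheck classes are hypothetical, and the check only recognizes conditions whose value is fixed by literals alone, such as 1 = 1, combined with AND or OR. Druid runs its own check on the AST produced by its SQL parser.

```java
import java.util.Objects;

// Hypothetical mini AST for a WHERE clause; not druid's SQLExpr classes.
interface Expr {}

final class Literal implements Expr {
    final Object value;
    Literal(Object value) { this.value = value; }
}

final class Column implements Expr {
    final String name;
    Column(String name) { this.name = name; }
}

final class BinaryOp implements Expr {
    final String op;          // "=", "AND" or "OR"
    final Expr left, right;
    BinaryOp(String op, Expr left, Expr right) {
        this.op = op;
        this.left = left;
        this.right = right;
    }
}

final class AlwaysTrueCheck {
    // Returns TRUE or FALSE when the condition's value is fixed by literals,
    // or null when it depends on column values.
    static Boolean constantValue(Expr e) {
        if (!(e instanceof BinaryOp)) {
            return null;
        }
        BinaryOp b = (BinaryOp) e;
        if (b.op.equals("=")) {
            if (b.left instanceof Literal && b.right instanceof Literal) {
                return Objects.equals(((Literal) b.left).value, ((Literal) b.right).value);
            }
            return null;
        }
        Boolean l = constantValue(b.left);
        Boolean r = constantValue(b.right);
        if (b.op.equals("AND")) {
            if (Boolean.FALSE.equals(l) || Boolean.FALSE.equals(r)) return Boolean.FALSE;
            return (l != null && r != null) ? Boolean.TRUE : null;
        }
        if (b.op.equals("OR")) {
            if (Boolean.TRUE.equals(l) || Boolean.TRUE.equals(r)) return Boolean.TRUE;
            return (l != null && r != null) ? Boolean.FALSE : null;
        }
        return null;
    }

    static boolean isAlwaysTrue(Expr where) {
        return Boolean.TRUE.equals(constantValue(where));
    }
}
```

With these rules, id = 1 OR 1 = 1 is reported as always true, while id = 1 AND 1 = 1 is not, because its value still depends on id.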

Let's test the select-where-alway-true-check: true property first. Our MyBatis Mapper file contains the condition where 1 = 1. Running the test produces the following error:

    java.sql.SQLException: sql injection violation, dbType mysql, druid-version 1.2.8, not terminal sql, token WHEN : select
        ......
        from TABLES
        when 1 = 1
        at com.alibaba.druid.wall.WallFilter.checkInternal(WallFilter.java:859) ~[druid-1.2.8.jar:1.2.8]
        at com.alibaba.druid.wall.WallFilter.connection_prepareStatement(WallFilter.java:295) ~[druid-1.2.8.jar:1.2.8]
        at com.alibaba.druid.filter.FilterChainImpl.connection_prepareStatement(FilterChainImpl.java:568) ~[druid-1.2.8.jar:1.2.8]
        at com.alibaba.druid.filter.FilterAdapter.connection_prepareStatement(FilterAdapter.java:930) ~[druid-1.2.8.jar:1.2.8]
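The message not terminal sql, token WHEN is what a parser reports when it has finished a statement but input is still left over. The toy sketch below shows the idea with a hypothetical grammar that only accepts SELECT <columns> FROM <table> [WHERE ...] over whitespace-separated tokens (ToyTerminalCheck is made up; druid's lexer and parser are far more complete): WHERE would be consumed, but WHEN is not part of the grammar, so it becomes the first leftover token.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

// Hypothetical toy parser: accepts only "SELECT cols FROM table [WHERE ...]".
final class ToyTerminalCheck {

    // Returns null if the whole input was consumed, otherwise a message
    // naming the first token the grammar could not consume.
    static String check(String sql) {
        List<String> tokens = Arrays.asList(sql.trim().split("\\s+"));
        int pos = 0;
        if (pos < tokens.size() && tokens.get(pos).equalsIgnoreCase("select")) {
            pos++;
            // Column list: everything up to FROM.
            while (pos < tokens.size() && !tokens.get(pos).equalsIgnoreCase("from")) {
                pos++;
            }
            if (pos < tokens.size()) {
                pos++;                           // consume FROM
                if (pos < tokens.size()) pos++;  // consume the table name
            }
            if (pos < tokens.size() && tokens.get(pos).equalsIgnoreCase("where")) {
                pos = tokens.size();             // toy: the WHERE clause runs to the end
            }
        }
        if (pos < tokens.size()) {
            return "not terminal sql, token " + tokens.get(pos).toUpperCase(Locale.ROOT);
        }
        return null;
    }
}
```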
    

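The call fails immediately with an SQL injection violation. Note that the logged statement reads when 1 = 1, not where 1 = 1, so the violation reported here is the strict syntax check's not terminal sql error (the parser stopped at the WHEN token) rather than the always-true check. Following the stack trace, the call reaches WallFilter through FilterAdapter.connection_prepareStatement, which simply hands the call to the next filter in the chain: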

    
        @Override
        public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
                throws SQLException {
            return chain.connection_prepareStatement(connection, sql);
        }
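The delegation shown above is what keeps the chain moving. A minimal sketch of the idea, with hypothetical SqlFilter and ToyFilterChain types rather than druid's FilterChain API: the chain tracks the index of the next filter, each filter calls back into the chain to continue, and a filter can stop the call by throwing before the end of the chain is reached.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical, simplified filter chain; not druid's FilterChain API.
interface SqlFilter {
    String prepareStatement(ToyFilterChain chain, String sql);
}

final class ToyFilterChain {
    private final List<SqlFilter> filters;
    private int pos = 0;  // index of the next filter to run

    ToyFilterChain(List<SqlFilter> filters) { this.filters = filters; }

    static ToyFilterChain of(SqlFilter... filters) {
        return new ToyFilterChain(Arrays.asList(filters));
    }

    String prepareStatement(String sql) {
        if (pos < filters.size()) {
            // Hand the call to the next filter; it calls back into this chain.
            return filters.get(pos++).prepareStatement(this, sql);
        }
        // End of the chain: stands in for the real JDBC driver call.
        return "prepared: " + sql;
    }
}

// Pass-through filter, analogous to FilterAdapter's default behavior.
final class PassThroughFilter implements SqlFilter {
    public String prepareStatement(ToyFilterChain chain, String sql) {
        return chain.prepareStatement(sql);
    }
}

// Toy rule: rejects the call before it reaches the end of the chain.
final class RejectingFilter implements SqlFilter {
    public String prepareStatement(ToyFilterChain chain, String sql) {
        if (sql.contains(";")) {
            throw new IllegalArgumentException("rejected by filter: " + sql);
        }
        return chain.prepareStatement(sql);
    }
}
```

Because the chain object carries its position, this sketch needs a new chain instance for every call.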
    
    

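As covered in an earlier article, this is the chain-of-responsibility pattern: all Filters are loaded first, and each Filter then calls into the next one through the chain, recursively. Now let's look at WallFilter's own implementation: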

        @Override
        public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
                throws SQLException {
            String dbType = connection.getDirectDataSource().getDbType();
            WallContext context = WallContext.create(dbType);
            try {
                WallCheckResult result = checkInternal(sql);
                context.setWallUpdateCheckItems(result.getUpdateCheckItems());
                sql = result.getSql();
                PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql);
                setSqlStatAttribute(stmt);
                return stmt;
            } finally {
                WallContext.clearContext();
            }
        }
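The WallContext.create(...) and WallContext.clearContext() calls above bracket the check with a per-request context. A minimal sketch of that lifecycle, assuming the context is kept in a ThreadLocal (ToyWallContext is hypothetical, and druid's WallContext holds more state, such as table and function statistics): create binds a context to the current thread, code further down the call stack reads it without it being passed as a parameter, and the finally block removes it so a pooled thread does not carry stale state into the next statement.

```java
// Hypothetical per-thread context, modeled on the create/current/clear usage above.
final class ToyWallContext {
    private static final ThreadLocal<ToyWallContext> CURRENT = new ThreadLocal<>();

    private final String dbType;

    private ToyWallContext(String dbType) { this.dbType = dbType; }

    String getDbType() { return dbType; }

    // Binds a fresh context to the calling thread.
    static ToyWallContext create(String dbType) {
        ToyWallContext context = new ToyWallContext(dbType);
        CURRENT.set(context);
        return context;
    }

    // Returns the context bound to the calling thread, or null if none.
    static ToyWallContext current() { return CURRENT.get(); }

    // Removes the binding; call this from a finally block.
    static void clearContext() { CURRENT.remove(); }

    // Usage pattern mirroring connection_prepareStatement above.
    static String checkWithContext(String dbType, String sql) {
        create(dbType);
        try {
            // Deeper code reads the context instead of receiving it as a parameter.
            return current().getDbType() + ": " + sql;
        } finally {
            clearContext();
        }
    }
}
```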
    

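It first creates a WallContext for the connection's dbType. Nothing complicated happens in this step: it mainly stores dbType in the WallContext, and the finally block clears the context again once the statement has been prepared. It then calls checkInternal: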

        private WallCheckResult checkInternal(String sql) throws SQLException {
            WallCheckResult checkResult = provider.check(sql);
            List<Violation> violations = checkResult.getViolations();
            if (violations.size() > 0) {
                ......
            }
            return checkResult;
        }
    
    

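The real work is delegated to provider. This provider is created when WallFilter is initialized, so let's look at the relevant part of the init method, which picks a provider according to the database type: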

            case mysql:
                case oceanbase:
                case drds:
                case mariadb:
                case h2:
                case presto:
                case trino:
                    if (config == null) {
                        config = new WallConfig(MySqlWallProvider.DEFAULT_CONFIG_DIR);
                    }
    
                    provider = new MySqlWallProvider(config);
                    break;
    ...
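The switch maps several database types onto one provider and only falls back to a default config when none was configured. A hypothetical sketch of that selection logic (ToyProviderFactory, the provider names and the default-config strings are made up, and the postgresql branch is only there for contrast; druid's real switch covers more types and builds WallProvider instances):

```java
import java.util.Locale;

// Hypothetical provider selection by database type.
final class ToyProviderFactory {

    static final class Provider {
        final String name;
        final String config;
        Provider(String name, String config) {
            this.name = name;
            this.config = config;
        }
    }

    static Provider create(String dbType, String config) {
        switch (dbType.toLowerCase(Locale.ROOT)) {
            // MySQL-compatible types share one provider through case fall-through.
            case "mysql":
            case "mariadb":
            case "h2":
                return new Provider("mysql-provider", config != null ? config : "default-mysql-config");
            case "postgresql":
                return new Provider("pg-provider", config != null ? config : "default-pg-config");
            default:
                throw new IllegalStateException("unsupported dbType: " + dbType);
        }
    }
}
```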
    

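The config passed in here is the WallFilter config we set up earlier; if none was configured, a default WallConfig is built from MySqlWallProvider.DEFAULT_CONFIG_DIR. Now let's look at the actual check logic in the provider: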

        private WallCheckResult checkInternal(String sql) {
            checkCount.incrementAndGet();
    
            WallContext context = WallContext.current();
    
            if (config.isDoPrivilegedAllow() && ispPrivileged()) {
                WallCheckResult checkResult = new WallCheckResult();
                checkResult.setSql(sql);
                return checkResult;
            }
    
            // first step, check whiteList
            boolean mulltiTenant = config.getTenantTablePattern() != null && config.getTenantTablePattern().length() > 0;
            if (!mulltiTenant) {
                WallCheckResult checkResult = checkWhiteAndBlackList(sql);
                if (checkResult != null) {
                    checkResult.setSql(sql);
                    return checkResult;
                }
            }
    
            hardCheckCount.incrementAndGet();
            final List<Violation> violations = new ArrayList<Violation>();
            List<SQLStatement> statementList = new ArrayList<SQLStatement>();
            boolean syntaxError = false;
            boolean endOfComment = false;
            try {
                SQLStatementParser parser = createParser(sql);
                parser.getLexer().setCommentHandler(WallCommentHandler.instance);
    
                if (!config.isCommentAllow()) {
                    parser.getLexer().setAllowComment(false); // deny comment
                }
                if (!config.isCompleteInsertValuesCheck()) {
                    parser.setParseCompleteValues(false);
                    parser.setParseValuesSize(config.getInsertValuesCheckSize());
                }
                
                parser.parseStatementList(statementList);
    
                final Token lastToken = parser.getLexer().token();
                if (lastToken != Token.EOF && config.isStrictSyntaxCheck()) {
                    violations.add(new IllegalSQLObjectViolation(ErrorCode.SYNTAX_ERROR, "not terminal sql, token "
                                                                                         + lastToken, sql));
                }
                endOfComment = parser.getLexer().isEndOfComment();
            } catch (NotAllowCommentException e) {
                violations.add(new IllegalSQLObjectViolation(ErrorCode.COMMENT_STATEMENT_NOT_ALLOW, "comment not allow", sql));
                incrementCommentDeniedCount();
            } catch (ParserException e) {
                syntaxErrorCount.incrementAndGet();
                syntaxError = true;
                if (config.isStrictSyntaxCheck()) {
                    violations.add(new SyntaxErrorViolation(e, sql));
                }
            } catch (Exception e) {
                if (config.isStrictSyntaxCheck()) {
                    violations.add(new SyntaxErrorViolation(e, sql));
                }
            }
    
            if (statementList.size() > 1 && !config.isMultiStatementAllow()) {
                violations.add(new IllegalSQLObjectViolation(ErrorCode.MULTI_STATEMENT, "multi-statement not allow", sql));
            }
    
            WallVisitor visitor = createWallVisitor();
            visitor.setSqlEndOfComment(endOfComment);
    
            if (statementList.size() > 0) {
                boolean lastIsHint = false;
                for (int i=0; i<statementList.size(); i++) {
                    SQLStatement stmt = statementList.get(i);
                    if ((i == 0 || lastIsHint) && stmt instanceof MySqlHintStatement) {
                        lastIsHint = true;
                        continue;
                    }
                    try {
                        stmt.accept(visitor);
                    } catch (ParserException e) {
                        violations.add(new SyntaxErrorViolation(e, sql));
                    }
                }
            }
    
            if (visitor.getViolations().size() > 0) {
                violations.addAll(visitor.getViolations());
            }
    
            Map<String, WallSqlTableStat> tableStat = context.getTableStats();
    
            boolean updateCheckHandlerEnable = false;
            {
                WallUpdateCheckHandler updateCheckHandler = config.getUpdateCheckHandler();
                if (updateCheckHandler != null) {
                    for (SQLStatement stmt : statementList) {
                        if (stmt instanceof SQLUpdateStatement) {
                            SQLUpdateStatement updateStmt = (SQLUpdateStatement) stmt;
                            SQLName table = updateStmt.getTableName();
                            if (table != null) {
                                String tableName = table.getSimpleName();
                                Set<String> updateCheckColumns = config.getUpdateCheckTable(tableName);
                                if (updateCheckColumns != null && updateCheckColumns.size() > 0) {
                                    updateCheckHandlerEnable = true;
                                    break;
                                }
                            }
                        }
                    }
                }
            }
    
            WallSqlStat sqlStat = null;
            if (violations.size() > 0) {
                violationCount.incrementAndGet();
    
                if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
                    sqlStat = addBlackSql(sql, tableStat, context.getFunctionStats(), violations, syntaxError);
                }
            } else {
                if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
                    boolean selectLimit = false;
                    if (config.getSelectLimit() > 0) {
                        for (SQLStatement stmt : statementList) {
                            if (stmt instanceof SQLSelectStatement) {
                                selectLimit = true;
                                break;
                            }
                        }
                    }
    
                    if (!selectLimit) {
                        sqlStat = addWhiteSql(sql, tableStat, context.getFunctionStats(), syntaxError);
                    }
                }
            }
            
            if(sqlStat == null && updateCheckHandlerEnable){
                sqlStat = new WallSqlStat(tableStat, context.getFunctionStats(), violations, syntaxError);
            }
    
            Map<String, WallSqlTableStat> tableStats = null;
            Map<String, WallSqlFunctionStat> functionStats = null;
            if (context != null) {
                tableStats = context.getTableStats();
                functionStats = context.getFunctionStats();
                recordStats(tableStats, functionStats);
            }
    
            WallCheckResult result;
            if (sqlStat != null) {
                context.setSqlStat(sqlStat);
                result = new WallCheckResult(sqlStat, statementList);
            } else {
                result = new WallCheckResult(null, violations, tableStats, functionStats, statementList, syntaxError);
            }
    
            String resultSql;
            if (visitor.isSqlModified()) {
                resultSql = SQLUtils.toSQLString(statementList, dbType);
            } else {
                resultSql = sql;
            }
            result.setSql(resultSql);
    
            result.setUpdateCheckItems(visitor.getUpdateCheckItems());
    
            return result;
        }
    

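It mainly does the following:
1. If no tenant table pattern is configured (not multi-tenant), look the SQL up in the white and black lists; if it is found in either, return the cached result directly.
2. Parse the SQL into a list of SQLStatement objects, since one SQL string may contain several statements. Parse errors, comments (when comments are not allowed) and leftover tokens (when strict syntax checking is on) are recorded as violations, and more than one statement is a violation unless multi-statement is allowed.
3. Call accept on each SQLStatement with a WallVisitor created from the config. The visitor collects violations, and a ParserException thrown during the visit is recorded as a SyntaxErrorViolation.
4. Record the SQL in the black list (if there are violations) or the white list (otherwise), subject to MAX_SQL_LENGTH and the other conditions in the code, and build the WallCheckResult; if the visitor modified a statement, the SQL is regenerated from the statement list.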
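The flow above can be condensed into a toy sketch. It assumes a bounded LRU map for the white and black lists (built on LinkedHashMap; druid's real cache structure and size limits may differ), naive splitting on ';' in place of the real parser, and a single made-up rule in place of the WallVisitor. ToyWallProvider and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

// Hypothetical condensed version of the provider's check flow.
final class ToyWallProvider {

    // Bounded LRU map: the least recently used entry is evicted when full.
    static final class LruCache<K, V> extends LinkedHashMap<K, V> {
        private final int maxSize;
        LruCache(int maxSize) {
            super(16, 0.75f, true);  // accessOrder = true gives LRU order
            this.maxSize = maxSize;
        }
        @Override
        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
            return size() > maxSize;
        }
    }

    private final Map<String, List<String>> whiteList = new LruCache<>(1000);
    private final Map<String, List<String>> blackList = new LruCache<>(1000);
    private final boolean multiStatementAllow;
    int hardCheckCount = 0;  // number of checks that actually parsed the SQL

    ToyWallProvider(boolean multiStatementAllow) {
        this.multiStatementAllow = multiStatementAllow;
    }

    // Returns the violation messages; an empty list means the SQL is allowed.
    List<String> check(String sql) {
        // Step 1: a cached verdict skips parsing entirely.
        if (whiteList.containsKey(sql)) return whiteList.get(sql);
        if (blackList.containsKey(sql)) return blackList.get(sql);

        hardCheckCount++;
        List<String> violations = new ArrayList<>();

        // Step 2: "parse" into statements (toy: split on ';').
        List<String> statements = new ArrayList<>();
        for (String part : sql.split(";")) {
            if (!part.trim().isEmpty()) statements.add(part.trim());
        }
        if (statements.size() > 1 && !multiStatementAllow) {
            violations.add("multi-statement not allow");
        }

        // Step 3: visit each statement (toy rule standing in for WallVisitor).
        for (String stmt : statements) {
            if (stmt.toLowerCase(Locale.ROOT).startsWith("drop ")) {
                violations.add("toy rule, DROP rejected: " + stmt);
            }
        }

        // Step 4: remember the verdict for next time.
        if (violations.isEmpty()) {
            whiteList.put(sql, violations);
        } else {
            blackList.put(sql, violations);
        }
        return violations;
    }
}
```

Caching by the raw SQL text is what makes repeated statements cheap: only the first occurrence pays for parsing and visiting.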


Article link: https://www.haomeiwen.com/subject/bhcytrtx.html