美文网首页
【解析SQL模板-1】MyBatis的SQL模板组合成可运行的SQL

【解析SQL模板-1】MyBatis的SQL模板组合成可运行的SQL

作者: 小胖学编程 | 来源:发表于2024-09-02 14:59 被阅读0次

    背景

    实现平台化的 MyBatis 能力：在页面上输入 MyBatis 的 SQL 模板并传入参数，最终解析成可运行的 SQL。

    实现原理

    引入依赖:

    <!-- MyBatis core: the MybatisGenerator code below uses its XMLMapperBuilder / MappedStatement / BoundSql
         to evaluate SQL templates. -->
    <!-- NOTE(review): the Java code below also uses CacheBuilder/LoadingCache, Escapers, IOUtils, StringUtils
         and the @Slf4j/@Data annotations; presumably these come from Guava, Commons IO, Commons Lang and
         Lombok, which are not listed here. Confirm and add them to the build. -->
    <dependency>
        <groupId>org.mybatis</groupId>
        <artifactId>mybatis</artifactId>
        <version>3.5.7</version>
    </dependency>
    

    mybatis的SQL生成器:

    1. 解析 MyBatis 模板，生成带 ? 占位符的预编译 SQL；
    2. 解析预编译 SQL，将参数值依次替换到 ? 占位符处。
    /**
     * Turns a MyBatis dynamic-SQL template plus a parameter map into a fully rendered SQL string.
     *
     * <p>Step 1: MyBatis parses the template (wrapped as a single-statement mapper XML) and evaluates
     * its dynamic tags against the parameters, producing SQL with {@code ?} placeholders.
     * Step 2: {@link PreparedStatementParser} substitutes the formatted, escaped parameter values.
     */
    @Slf4j
    public class MybatisGenerator {

        /** Mapper XML prologue; the user template becomes the body of the {@code selectData} select. */
        private static final String HEAD = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
                + "<!DOCTYPE mapper PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\" \"http://mybatis"
                + ".org/dtd/mybatis-3-mapper.dtd\">"
                + "<mapper namespace=\"customGenerator\">"
                + "<select id=\"selectData\" parameterType=\"map\" resultType=\"map\">\n";

        /** Closes the select and mapper elements opened by {@link #HEAD}. */
        private static final String FOOT = "\n</select></mapper>";

        /**
         * Upper bound on the number of distinct parsed templates kept in memory. Templates are supplied
         * by users, so an unbounded cache could grow without limit; tune to the expected template count.
         */
        private static final long MAX_CACHED_TEMPLATES = 1000L;

        /**
         * Parsed templates keyed by template text. A parsed template never changes, so entries are
         * evicted (by size, or after a day without use) instead of being periodically re-parsed.
         */
        private static final LoadingCache<String, MappedStatement> mappedStatementCache = CacheBuilder.newBuilder()
                .maximumSize(MAX_CACHED_TEMPLATES)
                .expireAfterAccess(1, TimeUnit.DAYS)
                .build(new CacheLoader<String, MappedStatement>() {
                    @Override
                    public MappedStatement load(@NotNull String key) {
                        // Every template is registered under the same id "selectData", so each one gets
                        // its own Configuration.
                        Configuration configuration = new Configuration();
                        configuration.setShrinkWhitespacesInSql(true);
                        String sourceSQL = HEAD + key + FOOT;
                        XMLMapperBuilder xmlMapperBuilder =
                                new XMLMapperBuilder(IOUtils.toInputStream(sourceSQL, Charset.forName("UTF-8")),
                                        configuration, null,
                                        null);
                        xmlMapperBuilder.parse();
                        return xmlMapperBuilder.getConfiguration().getMappedStatement("selectData");
                    }
                });

        /**
         * Renders a template into executable SQL.
         *
         * @param apiConfig  holds the MyBatis template (the content placed inside a select element)
         * @param conditions values for dynamic tags and {@code #{}} placeholders; exact keys are used first,
         *                   then property paths such as {@code user.name} or {@code ids[0]}; may be null
         * @return the SQL with every {@code #{}} placeholder replaced by its formatted literal
         * @throws UncheckedExecutionException if the template cannot be parsed (thrown by the cache)
         */
        public static String generateDsl(SQLConfig apiConfig, Map<String, Object> conditions) {
            MappedStatement mappedStatement = mappedStatementCache.getUnchecked(apiConfig.getSqlTemplate());
            BoundSql boundSql = mappedStatement.getBoundSql(conditions);
            List<ParameterMapping> mappings = boundSql.getParameterMappings();
            if (mappings.isEmpty()) {
                return boundSql.getSql();
            }
            Configuration configuration = mappedStatement.getConfiguration();
            List<PreparedStatementParameter> parameters = mappings.stream()
                    .map(ParameterMapping::getProperty)
                    .map(property -> resolveParameterValue(configuration, boundSql, conditions, property))
                    .map(PreparedStatementParameter::fromObject)
                    .collect(Collectors.toList());
            // Replace each ? placeholder, in order, with its formatted value.
            return PreparedStatementParser.parse(boundSql.getSql()).buildSql(parameters);
        }

        /**
         * Looks up the value for one placeholder property in the same order as MyBatis' parameter
         * handler. Values produced while evaluating the template (foreach items, bind variables, the
         * {@code _parameter} alias) win. Otherwise the property is read from the caller's map.
         */
        private static Object resolveParameterValue(Configuration configuration, BoundSql boundSql,
                Map<String, Object> conditions, String property) {
            if (boundSql.hasAdditionalParameter(property)) {
                return boundSql.getAdditionalParameter(property);
            }
            if (conditions == null) {
                return null;
            }
            if (conditions.containsKey(property)) {
                // An exact key match (including keys that themselves contain dots) keeps the direct lookup.
                return conditions.get(property);
            }
            // Otherwise evaluate the property as a path, e.g. "user.name" or "ids[0]".
            return configuration.newMetaObject(conditions).getValue(property);
        }

        /** Request payload carrying the MyBatis SQL template. */
        @Data
        public static class SQLConfig {
            // MyBatis SQL template (the content placed inside the select element).
            private String sqlTemplate;
        }
    }
    

    因为需要把参数值替换到 ?（占位符）处，所以需要判断参数值是否需要加引号，以及是否需要进行转义处理。

    
    /**
     * Renders Java values as SQL literal text for substitution into {@code ?} placeholders.
     *
     * <p>{@link #formatObject(Object)} produces the text and {@link #needsQuoting(Object)} tells the
     * caller whether that text must be wrapped in single quotes. Any text that ends up inside quotes is
     * backslash-escaped, so a value cannot terminate the literal early.
     */
    public final class ValueFormatter {
        /** Backslash-escapes backslash, newline, tab, backspace, form feed, CR, NUL, single quote and backquote. */
        private static final Escaper ESCAPER = Escapers.builder()
                .addEscape('\\', "\\\\")
                .addEscape('\n', "\\n")
                .addEscape('\t', "\\t")
                .addEscape('\b', "\\b")
                .addEscape('\f', "\\f")
                .addEscape('\r', "\\r")
                .addEscape('\u0000', "\\0")
                .addEscape('\'', "\\'")
                .addEscape('`', "\\`")
                .build();

        /** Text used to represent SQL NULL. */
        public static final String NULL_MARKER = "\\N";
        // SimpleDateFormat is not thread-safe, so each thread gets its own instance (JVM default time zone).
        private static final ThreadLocal<SimpleDateFormat> DATE_FORMAT =
                ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd"));
        private static final ThreadLocal<SimpleDateFormat> DATE_TIME_FORMAT =
                ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
        // DateTimeFormatter is immutable and thread-safe; same pattern as DATE_TIME_FORMAT.
        private static final java.time.format.DateTimeFormatter LOCAL_DATE_TIME_FORMAT =
                java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

        /** Renders bytes as a run of backslash-x escapes with two uppercase hex digits each; null stays null. */
        public static String formatBytes(byte[] bytes) {
            if (bytes == null) {
                return null;
            } else {
                char[] hexArray =
                        new char[] {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
                char[] hexChars = new char[bytes.length * 4];

                for (int j = 0; j < bytes.length; ++j) {
                    int v = bytes[j] & 255;
                    hexChars[j * 4] = '\\';
                    hexChars[j * 4 + 1] = 'x';
                    hexChars[j * 4 + 2] = hexArray[v / 16];
                    hexChars[j * 4 + 3] = hexArray[v % 16];
                }

                return new String(hexChars);
            }
        }

        public static String formatInt(int myInt) {
            return Integer.toString(myInt);
        }

        public static String formatDouble(double myDouble) {
            return Double.toString(myDouble);
        }

        public static String formatChar(char myChar) {
            return Character.toString(myChar);
        }

        public static String formatLong(long myLong) {
            return Long.toString(myLong);
        }

        public static String formatFloat(float myFloat) {
            return Float.toString(myFloat);
        }

        /** Plain (non-scientific) notation; null becomes {@link #NULL_MARKER}. */
        public static String formatBigDecimal(BigDecimal myBigDecimal) {
            return myBigDecimal != null ? myBigDecimal.toPlainString() : NULL_MARKER;
        }

        public static String formatShort(short myShort) {
            return Short.toString(myShort);
        }

        /** Escapes a string for use inside a quoted literal; null becomes {@link #NULL_MARKER}. */
        public static String formatString(String myString) {
            return escape(myString);
        }

        public static String formatNull() {
            return NULL_MARKER;
        }

        public static String formatByte(byte myByte) {
            return Byte.toString(myByte);
        }

        /** Booleans are rendered as 1 / 0. */
        public static String formatBoolean(boolean myBoolean) {
            return myBoolean ? "1" : "0";
        }

        public static String formatUUID(UUID x) {
            return x.toString();
        }

        public static String formatBigInteger(BigInteger x) {
            return x.toString();
        }

        /**
         * Formats any supported value as SQL literal text (without surrounding quotes).
         *
         * @param x the value; may be null
         * @return the literal text, or null when {@code x} is null
         */
        public static String formatObject(Object x) {
            if (x == null) {
                return null;
            } else if (x instanceof Byte) {
                return formatInt(((Byte) x).intValue());
            } else if (x instanceof String) {
                return formatString((String) x);
            } else if (x instanceof BigDecimal) {
                return formatBigDecimal((BigDecimal) x);
            } else if (x instanceof Short) {
                return formatShort((Short) x);
            } else if (x instanceof Integer) {
                return formatInt((Integer) x);
            } else if (x instanceof Long) {
                return formatLong((Long) x);
            } else if (x instanceof Float) {
                return formatFloat((Float) x);
            } else if (x instanceof Double) {
                return formatDouble((Double) x);
            } else if (x instanceof byte[]) {
                return formatBytes((byte[]) x);
            } else if (x instanceof Boolean) {
                return formatBoolean((Boolean) x);
            } else if (x instanceof UUID) {
                return formatUUID((UUID) x);
            } else if (x instanceof BigInteger) {
                return formatBigInteger((BigInteger) x);
            } else if (x instanceof java.sql.Date) {
                return getDateFormat().format((java.util.Date) x);
            } else if (x instanceof java.util.Date) {
                // Covers java.util.Date and its JDBC subclasses (Timestamp, Time); sub-second precision is dropped.
                return getDateTimeFormat().format((java.util.Date) x);
            } else if (x instanceof java.time.LocalDate) {
                // ISO-8601, i.e. yyyy-MM-dd for four-digit years.
                return x.toString();
            } else if (x instanceof java.time.LocalDateTime) {
                return LOCAL_DATE_TIME_FORMAT.format((java.time.LocalDateTime) x);
            } else {
                String text = String.valueOf(x);
                // Text that will be wrapped in quotes must be escaped; otherwise a Character, enum or Map
                // whose text contains a quote could break out of the literal.
                return needsQuoting(x) ? escape(text) : text;
            }
        }

        /**
         * Tells whether the text produced by {@link #formatObject(Object)} must be wrapped in single quotes.
         * Numbers, booleans, collections and (non-byte) arrays are emitted bare; everything else is quoted.
         */
        public static boolean needsQuoting(Object o) {
            if (o == null || o instanceof Number || o instanceof Boolean || o instanceof Collection) {
                return false;
            }
            if (o instanceof byte[]) {
                // formatBytes emits backslash-x escapes, which only form a value inside a quoted literal.
                return true;
            }
            return !o.getClass().isArray();
        }

        private static SimpleDateFormat getDateFormat() {
            return DATE_FORMAT.get();
        }

        private static SimpleDateFormat getDateTimeFormat() {
            return DATE_TIME_FORMAT.get();
        }

        /** Backslash-escapes {@code s} for a quoted literal; null becomes {@link #NULL_MARKER}. */
        public static String escape(String s) {
            return s == null ? NULL_MARKER : ESCAPER.escape(s);
        }

        /**
         * Wraps {@code s} in backquotes after escaping it.
         *
         * @throws IllegalArgumentException if {@code s} is null
         */
        public static String quoteIdentifier(String s) {
            if (s == null) {
                throw new IllegalArgumentException("Can't quote null as identifier");
            } else {
                StringBuilder sb = new StringBuilder(s.length() + 2);
                sb.append('`');
                sb.append(ESCAPER.escape(s));
                sb.append('`');
                return sb.toString();
            }
        }
    }
    

    定义预编译的参数:

    /**
     * One bound value, already rendered to SQL text, plus whether that text must be single-quoted.
     * A null text is stored as the marker {@code \N} and rendered as the SQL keyword {@code null}.
     */
    public final class PreparedStatementParameter {
        // Internal representation of SQL NULL.
        private static final String NULL_TEXT = "\\N";

        private static final PreparedStatementParameter NULL_PARAM =
                new PreparedStatementParameter((String) null, false);
        private static final PreparedStatementParameter TRUE_PARAM =
                new PreparedStatementParameter("1", false);
        private static final PreparedStatementParameter FALSE_PARAM =
                new PreparedStatementParameter("0", false);

        private final String stringValue;
        private final boolean quoteNeeded;

        /**
         * @param stringValue rendered SQL text; null means SQL NULL
         * @param quoteNeeded whether the text must be wrapped in single quotes when rendered
         */
        public PreparedStatementParameter(String stringValue, boolean quoteNeeded) {
            this.stringValue = (stringValue != null) ? stringValue : NULL_TEXT;
            this.quoteNeeded = quoteNeeded;
        }

        /** Builds a parameter from an arbitrary value, letting ValueFormatter decide text and quoting. */
        public static PreparedStatementParameter fromObject(Object x) {
            if (x == null) {
                return NULL_PARAM;
            }
            String rendered = ValueFormatter.formatObject(x);
            boolean quoted = ValueFormatter.needsQuoting(x);
            return new PreparedStatementParameter(rendered, quoted);
        }

        /** Shared SQL NULL parameter. */
        public static PreparedStatementParameter nullParameter() {
            return NULL_PARAM;
        }

        /** Shared parameters for booleans, rendered as 1 / 0. */
        public static PreparedStatementParameter boolParameter(boolean value) {
            if (value) {
                return TRUE_PARAM;
            }
            return FALSE_PARAM;
        }

        /** Text to splice into the SQL: {@code null}, the bare value, or the value in single quotes. */
        String getRegularValue() {
            if (NULL_TEXT.equals(stringValue)) {
                return "null";
            }
            if (!quoteNeeded) {
                return stringValue;
            }
            return "'" + stringValue + "'";
        }

        /** Raw stored text (null is the {@code \N} marker), without quoting. */
        String getBatchValue() {
            return stringValue;
        }

        @Override
        public String toString() {
            return stringValue;
        }
    }
    

    预编译解析器:将参数替换到占位符

    /**
     * Splits SQL into literal "parts" around bind slots and renders it back with values substituted.
     *
     * <p>Regular mode: every {@code ?} outside single quotes, backquotes and SQL comments becomes a slot
     * holding {@link #PARAM_MARKER}; {@code parts} then has exactly one more entry than there are slots.
     *
     * <p>Values mode (an {@code INSERT INTO ... VALUES (} prefix matches, or an explicit start position is
     * given): scanning starts at that position and, roughly, each comma-separated value inside the VALUES
     * parentheses becomes a slot. {@code ?} values stay {@link #PARAM_MARKER}; literals keep their text,
     * except {@code true}/{@code false} become {@code 1}/{@code 0} and {@code NULL} becomes {@link #NULL_MARKER}.
     */
    public class PreparedStatementParser {

        static final String PARAM_MARKER = "?";
        static final String NULL_MARKER = "\\N";

        // Finds an INSERT ... VALUES ( prefix; matcher.end() - 1 is the index of the opening parenthesis.
        private static final Pattern VALUES = Pattern.compile(
                "(?i)INSERT\\s+INTO\\s+.+VALUES\\s*\\(",
                Pattern.MULTILINE | Pattern.DOTALL);

        // One list of slot values per VALUES row (a single list in regular mode).
        private final List<List<String>> parameters;
        // Literal SQL text between slots.
        private final List<String> parts;
        private boolean valuesMode;

        private PreparedStatementParser() {
            parameters = new ArrayList<>();
            parts = new ArrayList<>();
            valuesMode = false;
        }

        /**
         * Parses {@code sql}, detecting values mode automatically.
         *
         * @throws IllegalArgumentException if {@code sql} is blank
         */
        public static PreparedStatementParser parse(String sql) {
            return parse(sql, -1);
        }

        /**
         * Parses {@code sql}; a positive {@code valuesEndPosition} forces values mode starting at that index.
         *
         * @throws IllegalArgumentException if {@code sql} is blank
         */
        public static PreparedStatementParser parse(String sql, int valuesEndPosition) {
            if (StringUtils.isBlank(sql)) {
                throw new IllegalArgumentException("SQL may not be blank");
            }
            PreparedStatementParser parser = new PreparedStatementParser();
            parser.parseSQL(sql, valuesEndPosition);
            return parser;
        }

        List<List<String>> getParameters() {
            return Collections.unmodifiableList(parameters);
        }

        List<String> getParts() {
            return Collections.unmodifiableList(parts);
        }

        boolean isValuesMode() {
            return valuesMode;
        }

        private void reset() {
            parameters.clear();
            parts.clear();
            valuesMode = false;
        }

        /** Single-pass scanner filling {@link #parts} and {@link #parameters}. */
        private void parseSQL(String sql, int valuesEndPosition) {
            reset();
            List<String> currentParamList = new ArrayList<String>();
            boolean afterBackSlash = false;
            boolean inQuotes = false;
            boolean inBackQuotes = false;
            boolean inSingleLineComment = false;
            boolean inMultiLineComment = false;
            boolean whiteSpace = false;
            int endPosition = 0;
            if (valuesEndPosition > 0) {
                valuesMode = true;
                endPosition = valuesEndPosition;
            } else {
                Matcher matcher = VALUES.matcher(sql);
                if (matcher.find()) {
                    valuesMode = true;
                    endPosition = matcher.end() - 1;
                }
            }

            int currentParensLevel = 0;
            int quotedStart = 0;
            int partStart = 0;
            int sqlLength = sql.length();
            // idxStart/idxEnd delimit the value currently being collected (values mode bookkeeping).
            for (int i = valuesMode ? endPosition : 0, idxStart = i, idxEnd = i; i < sqlLength; i++) {
                char c = sql.charAt(i);
                if (inSingleLineComment) {
                    if (c == '\n') {
                        inSingleLineComment = false;
                    }
                } else if (inMultiLineComment) {
                    if (c == '*' && sqlLength > i + 1 && sql.charAt(i + 1) == '/') {
                        inMultiLineComment = false;
                        i++;
                    }
                } else if (afterBackSlash) {
                    // The character after a backslash is taken literally (e.g. an escaped quote).
                    afterBackSlash = false;
                } else if (c == '\\') {
                    afterBackSlash = true;
                } else if (c == '\'' && !inBackQuotes) {
                    inQuotes = !inQuotes;
                    if (inQuotes) {
                        quotedStart = i;
                    } else {
                        // Closing quote: the whole quoted literal is the current value.
                        idxStart = quotedStart;
                        idxEnd = i + 1;
                    }
                } else if (c == '`' && !inQuotes) {
                    inBackQuotes = !inBackQuotes;
                } else if (!inQuotes && !inBackQuotes) {
                    if (c == '?') {
                        if (currentParensLevel > 0) {
                            idxStart = i;
                            idxEnd = i + 1;
                        }
                        if (!valuesMode) {
                            parts.add(sql.substring(partStart, i));
                            partStart = i + 1;
                            currentParamList.add(PARAM_MARKER);
                        }
                    } else if (c == '-' && sqlLength > i + 1 && sql.charAt(i + 1) == '-') {
                        inSingleLineComment = true;
                        i++;
                    } else if (c == '/' && sqlLength > i + 1 && sql.charAt(i + 1) == '*') {
                        inMultiLineComment = true;
                        i++;
                    } else if (c == ',') {
                        if (valuesMode && idxEnd > idxStart) {
                            currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                            parts.add(sql.substring(partStart, idxStart));
                            partStart = idxEnd;
                            idxEnd = i;
                            idxStart = idxEnd;
                        }
                        idxStart++;
                        idxEnd++;
                    } else if (c == '(') {
                        currentParensLevel++;
                        idxStart++;
                        idxEnd++;
                    } else if (c == ')') {
                        currentParensLevel--;
                        if (valuesMode && currentParensLevel == 0) {
                            // End of one VALUES row: flush the last value, then the row's list.
                            if (idxEnd > idxStart) {
                                currentParamList.add(typeTransformParameterValue(sql.substring(idxStart, idxEnd)));
                                parts.add(sql.substring(partStart, idxStart));
                                partStart = idxEnd;
                                idxEnd = i;
                                idxStart = idxEnd;
                            }
                            if (!currentParamList.isEmpty()) {
                                parameters.add(currentParamList);
                                currentParamList = new ArrayList<>(currentParamList.size());
                            }
                        }
                    } else if (Character.isWhitespace(c)) {
                        whiteSpace = true;
                    } else if (currentParensLevel > 0) {
                        if (whiteSpace) {
                            idxStart = i;
                            idxEnd = i + 1;
                        } else {
                            idxEnd++;
                        }
                        whiteSpace = false;
                    }
                }
            }
            if (!valuesMode && !currentParamList.isEmpty()) {
                parameters.add(currentParamList);
            }
            String lastPart = sql.substring(partStart, sqlLength);
            parts.add(lastPart);
        }

        /** Normalizes a literal VALUES entry: true/false become 1/0 and NULL becomes {@link #NULL_MARKER}. */
        private static String typeTransformParameterValue(String paramValue) {
            if (paramValue == null) {
                return null;
            }
            if (Boolean.TRUE.toString().equalsIgnoreCase(paramValue)) {
                return "1";
            }
            if (Boolean.FALSE.toString().equalsIgnoreCase(paramValue)) {
                return "0";
            }
            if ("NULL".equalsIgnoreCase(paramValue)) {
                return NULL_MARKER;
            }
            return paramValue;
        }

        /**
         * Renders the parsed SQL with {@code binds} substituted, in order, into the {@code ?} slots.
         * Slots that hold a literal (values mode) are emitted as stored.
         *
         * @param binds one value per {@code ?} slot in order of appearance; may be null or empty when the
         *              SQL has no {@code ?} slots
         * @return the rendered SQL
         * @throws IllegalArgumentException if the number of binds differs from the number of {@code ?} slots,
         *                                  e.g. because a placeholder ended up inside a quoted literal
         */
        public String buildSql(List<PreparedStatementParameter> binds) {
            int bindCount = binds == null ? 0 : binds.size();
            List<String> slots = flattenParameters();
            int slotCount = Math.min(slots.size(), this.parts.size() - 1);

            int placeholderCount = 0;
            for (int k = 0; k < slotCount; k++) {
                if (PARAM_MARKER.equals(slots.get(k))) {
                    placeholderCount++;
                }
            }
            if (placeholderCount != bindCount) {
                throw new IllegalArgumentException("SQL has " + placeholderCount
                        + " '?' placeholder(s) but " + bindCount + " value(s) were supplied");
            }

            StringBuilder sb = new StringBuilder(this.parts.get(0));
            int nextBind = 0;
            for (int i = 1; i < this.parts.size(); i++) {
                // A part without a recorded slot value appends the text "null".
                String slot = i - 1 < slots.size() ? slots.get(i - 1) : null;
                if (PARAM_MARKER.equals(slot)) {
                    // A ?-slot (e.g. from #{}) takes the next bind, quoted and escaped as needed.
                    sb.append(binds.get(nextBind++).getRegularValue());
                } else {
                    sb.append(slot);
                }
                sb.append(this.parts.get(i));
            }
            return sb.toString();
        }

        /** Concatenates the per-row slot lists into one list indexed by slot position. */
        private List<String> flattenParameters() {
            List<String> flat = new ArrayList<>();
            for (List<String> row : this.parameters) {
                flat.addAll(row);
            }
            return flat;
        }
    }
    

    文章参考

    Mybatis interceptor 获取clickhouse最终执行的sql

    【Mybatis】单独使用mybatis的SQL模板解析

    相关文章

      网友评论

          本文标题:【解析SQL模板-1】MyBatis的SQL模板组合成可运行的SQL

          本文链接:https://www.haomeiwen.com/subject/kivcljtx.html