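MyBatis is a widely used ORM persistence framework. On top of its core job of executing dynamic SQL, it offers a plugin mechanism. Plugins extend MyBatis non-invasively by intercepting calls, in the spirit of aspect-oriented programming. The best-known plugin is the pagination plugin PageHelper.
MyBatis Plugin Source Code Analysis
The core plugin code lives in the org.apache.ibatis.plugin package, which mainly contains the following types.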
- Annotations
- Intercepts
// The annotation that specifies the target methods to intercept.
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Intercepts {
  /**
   * Returns method signatures to intercept.
   *
   * @return method signatures
   */
  Signature[] value();
}
- Signature
/**
 * The annotation that indicates the method signature.
 * Declares the basic information of the intercepted method: type, method name and parameter types.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({})
public @interface Signature {
  Class<?> type();

  String method();

  Class<?>[] args();
}
- Interface
- Interceptor
The common interface of all MyBatis plugin interceptors. A custom plugin must implement its intercept method. plugin and setProperties have default implementations.
public interface Interceptor {

  Object intercept(Invocation invocation) throws Throwable;

  default Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  default void setProperties(Properties properties) {
    // NOP
  }

}
- Core classes
- Invocation
The fields of Invocation mirror the attributes of the Signature annotation: the target object (an instance of type), the method, and its args. Its proceed method invokes the target's method via reflection.
public class Invocation {

  private final Object target;
  private final Method method;
  private final Object[] args;

  public Invocation(Object target, Method method, Object[] args) {
    this.target = target;
    this.method = method;
    this.args = args;
  }

  public Object getTarget() {
    return target;
  }

  public Method getMethod() {
    return method;
  }

  public Object[] getArgs() {
    return args;
  }

  public Object proceed() throws InvocationTargetException, IllegalAccessException {
    return method.invoke(target, args);
  }

}
- Plugin
Plugin implements InvocationHandler, the interface required by JDK dynamic proxies, and its wrap method returns the corresponding proxy object.
public class Plugin implements InvocationHandler {

  private final Object target;
  private final Interceptor interceptor;
  private final Map<Class<?>, Set<Method>> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }

  // When the target needs to be intercepted, create a dynamic proxy for it; otherwise return the target itself
  public static Object wrap(Object target, Interceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  // Check whether the invoked method is intercepted by the Interceptor; if so, run the interceptor's intercept method
  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      if (methods != null && methods.contains(method)) {
        return interceptor.intercept(new Invocation(target, method, args));
      }
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }

  // Collect the methods declared by the @Signature entries inside the custom Interceptor's @Intercepts annotation
  private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
    Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
    // issue #251
    if (interceptsAnnotation == null) {
      throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
    }
    Signature[] sigs = interceptsAnnotation.value();
    Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
    for (Signature sig : sigs) {
      Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
      try {
        Method method = sig.type().getMethod(sig.method(), sig.args());
        methods.add(method);
      } catch (NoSuchMethodException e) {
        throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
      }
    }
    return signatureMap;
  }

  // Walk up the class hierarchy and collect every implemented interface that appears in signatureMap
  private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
    Set<Class<?>> interfaces = new HashSet<>();
    while (type != null) {
      for (Class<?> c : type.getInterfaces()) {
        if (signatureMap.containsKey(c)) {
          interfaces.add(c);
        }
      }
      type = type.getSuperclass();
    }
    return interfaces.toArray(new Class<?>[0]);
  }

}
- InterceptorChain
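Configuration holds an InterceptorChain instance as a field. Interceptors are registered through Configuration.addInterceptor, for example by XMLConfigBuilder when it parses the <plugins> element, or by mybatis-spring's SqlSessionFactoryBean when plugins are configured on it. pluginAll passes the target through every registered interceptor's plugin method in turn.
public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }

  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }

}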
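- PluginException
The exception Plugin throws when an interceptor is misconfigured, for example when it lacks an @Intercepts annotation or a @Signature names a method that does not exist.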
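You can see how these pieces cooperate without putting MyBatis on the classpath. The following is a minimal sketch that re-implements the core idea of Plugin.wrap with plain JDK proxies. All names here are hypothetical, not MyBatis types: MiniIntercepts, MiniSignature, MiniInvocation, MiniInterceptor, MiniPlugin, Greeter and UpperCaseInterceptor. Unlike the real Plugin, it scans only the target's directly implemented interfaces. Only the method named in the annotation is routed to intercept; every other call passes straight through to the target.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for @Signature: type, method name and parameter types to intercept.
@Retention(RetentionPolicy.RUNTIME)
@Target({})
@interface MiniSignature {
  Class<?> type();
  String method();
  Class<?>[] args();
}

// Hypothetical stand-in for @Intercepts.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface MiniIntercepts {
  MiniSignature[] value();
}

// Hypothetical stand-in for Invocation: target + method + args, plus proceed().
final class MiniInvocation {
  private final Object target;
  private final Method method;
  private final Object[] args;

  MiniInvocation(Object target, Method method, Object[] args) {
    this.target = target;
    this.method = method;
    this.args = args;
  }

  Object proceed() throws InvocationTargetException, IllegalAccessException {
    return method.invoke(target, args); // call the real method on the wrapped object
  }
}

interface MiniInterceptor {
  Object intercept(MiniInvocation invocation) throws Throwable;
}

final class MiniPlugin {
  // Same idea as Plugin.wrap: proxy only the interfaces named in the annotation.
  static Object wrap(Object target, MiniInterceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
    for (MiniSignature sig : interceptor.getClass().getAnnotation(MiniIntercepts.class).value()) {
      try {
        signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>())
            .add(sig.type().getMethod(sig.method(), sig.args()));
      } catch (NoSuchMethodException e) {
        throw new IllegalArgumentException("No such method: " + sig.method(), e);
      }
    }
    Set<Class<?>> interfaces = new HashSet<>();
    for (Class<?> c : target.getClass().getInterfaces()) { // superclasses are not scanned here
      if (signatureMap.containsKey(c)) {
        interfaces.add(c);
      }
    }
    if (interfaces.isEmpty()) {
      return target; // nothing to intercept: hand back the original object
    }
    return Proxy.newProxyInstance(target.getClass().getClassLoader(),
        interfaces.toArray(new Class<?>[0]),
        (proxy, method, args) -> {
          Set<Method> methods = signatureMap.get(method.getDeclaringClass());
          if (methods != null && methods.contains(method)) {
            return interceptor.intercept(new MiniInvocation(target, method, args));
          }
          try {
            return method.invoke(target, args); // not declared: pass straight through
          } catch (InvocationTargetException e) {
            throw e.getCause(); // unwrap, similar in spirit to ExceptionUtil.unwrapThrowable
          }
        });
  }
}

interface Greeter {
  String greet(String name);
  String farewell(String name);
}

class SimpleGreeter implements Greeter {
  public String greet(String name) { return "Hello, " + name; }
  public String farewell(String name) { return "Bye, " + name; }
}

// Declares only Greeter.greet(String), so farewell is never intercepted.
@MiniIntercepts({@MiniSignature(type = Greeter.class, method = "greet", args = {String.class})})
class UpperCaseInterceptor implements MiniInterceptor {
  public Object intercept(MiniInvocation invocation) throws Throwable {
    return ((String) invocation.proceed()).toUpperCase();
  }
}
```

When UpperCaseInterceptor wraps a SimpleGreeter, greet("mybatis") returns "HELLO, MYBATIS". farewell("mybatis") still returns "Bye, mybatis" because the annotation does not list it.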
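MyBatis Plugin Call Chain
The Plugin class creates proxies in its wrap method, which is reached through the following call chain, starting from four factory methods of Configuration: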
Configuration
  -> newParameterHandler
       -> interceptorChain.pluginAll
            -> interceptor.plugin
                 -> Plugin.wrap(target, this)
  -> newStatementHandler
       -> interceptorChain.pluginAll
            -> interceptor.plugin
                 -> Plugin.wrap(target, this)
  -> newExecutor
       -> interceptorChain.pluginAll
            -> interceptor.plugin
                 -> Plugin.wrap(target, this)
  -> newResultSetHandler
       -> interceptorChain.pluginAll
            -> interceptor.plugin
                 -> Plugin.wrap(target, this)
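Because pluginAll is called only when these objects are created, the call chain shows that MyBatis plugins can enhance exactly four classes: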
1. ParameterHandler
2. StatementHandler
3. Executor
4. ResultSetHandler
These four classes are MyBatis's four core components, and together they carry out the JDBC operations.
The Four Core Components of MyBatis
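One consequence of pluginAll deserves a closer look. Each interceptor wraps the object returned by the previous one, so the interceptor registered last becomes the outermost proxy and its logic runs first. The sketch below demonstrates this with plain JDK proxies. Handler, TracingInterceptor and MiniChain are hypothetical names, not MyBatis types. MiniChain.pluginAll copies the loop from InterceptorChain.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical component interface standing in for Executor, StatementHandler, etc.
interface Handler {
  String handle(List<String> trace);
}

// Hypothetical interceptor: plugin(...) wraps a target in a JDK proxy that
// records "name:before" and "name:after" around every handle(...) call.
final class TracingInterceptor {
  private final String name;

  TracingInterceptor(String name) {
    this.name = name;
  }

  Handler plugin(Handler target) {
    return (Handler) Proxy.newProxyInstance(
        Handler.class.getClassLoader(),
        new Class<?>[] {Handler.class},
        (proxy, method, args) -> {
          if (!method.getName().equals("handle")) {
            return method.invoke(target, args); // toString/hashCode/equals pass through
          }
          @SuppressWarnings("unchecked")
          List<String> trace = (List<String>) args[0];
          trace.add(name + ":before");
          try {
            Object result = method.invoke(target, args); // next proxy, or the real target
            trace.add(name + ":after");
            return result;
          } catch (InvocationTargetException e) {
            throw e.getCause();
          }
        });
  }
}

// Hypothetical analogue of InterceptorChain, using the same loop as pluginAll.
final class MiniChain {
  private final List<TracingInterceptor> interceptors = new ArrayList<>();

  void addInterceptor(TracingInterceptor interceptor) {
    interceptors.add(interceptor);
  }

  Handler pluginAll(Handler target) {
    for (TracingInterceptor interceptor : interceptors) {
      target = interceptor.plugin(target); // wrap the result of the previous round
    }
    return target;
  }
}
```

Registering A first and then B produces the trace B:before, A:before, target, A:after, B:after. Keep this in mind when several plugins intercept the same method.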
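- ParameterHandler: converts the parameters passed in by the user into the parameters the JDBC Statement needs, and sets them on it.
- StatementHandler: encapsulates JDBC Statement operations, such as setting parameters and turning the Statement's result set into a List. It delegates these two tasks to ParameterHandler and ResultSetHandler.
- ResultSetHandler: converts the ResultSet returned by JDBC into a List.
- Executor: the MyBatis executor and the core of MyBatis's scheduling. It drives SQL execution, obtaining the SQL to run (BoundSql) from the MappedStatement, and maintains the query cache.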
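During a query, the four components are called in a fixed order. In MyBatis 3, SimpleExecutor.doQuery asks Configuration for a StatementHandler, whose constructor obtains the ParameterHandler and ResultSetHandler. It then calls prepare, then parameterize (which delegates to ParameterHandler.setParameters), and finally query, which ends in ResultSetHandler.handleResultSets. The toy model below only records that order. The Toy* types are hypothetical, and their signatures are far simpler than the real ones: there is no JDBC, caching or batching.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, heavily simplified stand-ins for two of the four components.
interface ToyParameterHandler {
  void setParameters(List<String> trace);
}

interface ToyResultSetHandler {
  List<Object> handleResultSets(List<String> trace);
}

// Simplified StatementHandler: owns the other two handlers and delegates to them.
final class ToyStatementHandler {
  private final ToyParameterHandler parameterHandler;
  private final ToyResultSetHandler resultSetHandler;

  ToyStatementHandler(ToyParameterHandler parameterHandler, ToyResultSetHandler resultSetHandler) {
    this.parameterHandler = parameterHandler;
    this.resultSetHandler = resultSetHandler;
  }

  void prepare(List<String> trace) {
    trace.add("StatementHandler.prepare");
  }

  void parameterize(List<String> trace) {
    trace.add("StatementHandler.parameterize");
    parameterHandler.setParameters(trace); // parameter binding is delegated
  }

  List<Object> query(List<String> trace) {
    trace.add("StatementHandler.query");
    return resultSetHandler.handleResultSets(trace); // result mapping is delegated
  }
}

// Simplified Executor: drives the StatementHandler in the order doQuery does.
final class ToyExecutor {
  List<Object> query(List<String> trace) {
    trace.add("Executor.query");
    ToyStatementHandler handler = new ToyStatementHandler(
        t -> t.add("ParameterHandler.setParameters"),
        t -> {
          t.add("ResultSetHandler.handleResultSets");
          return new ArrayList<>(); // no real rows in this toy model
        });
    handler.prepare(trace);
    handler.parameterize(trace);
    return handler.query(trace);
  }
}
```

Each of these calls is a point where a plugin targeting the corresponding type can step in.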
[Figure: MyBatis core architecture diagram]
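MyBatis Plugin Example
With the plugin source code and the roles of the components covered, let's get hands-on with MyBatis plugins by building an extension that prints every executed SQL statement. The heart of a MyBatis plugin is a class that implements the Interceptor interface. In a Spring Boot application using mybatis-spring-boot-starter, registering that class as a bean in the IoC container is enough, because the auto-configuration adds Interceptor beans to the SqlSessionFactory. With plain mybatis-spring, pass it to SqlSessionFactoryBean.setPlugins instead.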
import java.text.DateFormat;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.regex.Matcher;

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;

// Logs each SQL statement run through Executor.update/query, with parameters inlined, plus its elapsed time.
// Lombok's @Slf4j generates the static "log" field used below.
@Slf4j
@Intercepts({
    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
        RowBounds.class, ResultHandler.class})})
public class PrintSqlInterceptor implements Interceptor {

  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
    Object parameter = null;
    if (invocation.getArgs().length > 1) {
      parameter = invocation.getArgs()[1];
    }
    String sqlId = mappedStatement.getId();
    BoundSql boundSql = mappedStatement.getBoundSql(parameter);
    Configuration configuration = mappedStatement.getConfiguration();
    long start = System.currentTimeMillis();
    Object returnValue = invocation.proceed(); // run the real Executor method
    long time = System.currentTimeMillis() - start;
    printSql(configuration, boundSql, time, sqlId);
    return returnValue;
  }

  private static void printSql(Configuration configuration, BoundSql boundSql, long time, String sqlId) {
    Object parameterObject = boundSql.getParameterObject();
    List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
    // Collapse spaces, line breaks, tabs, etc. into single spaces
    String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
    if (!parameterMappings.isEmpty() && parameterObject != null) {
      TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
      if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
        // A parameter type with its own TypeHandler (String, Integer, ...) is the value itself
        sql = sql.replaceFirst("\\?", getParameterValue(parameterObject));
      } else {
        MetaObject metaObject = configuration.newMetaObject(parameterObject);
        for (ParameterMapping parameterMapping : parameterMappings) {
          String propertyName = parameterMapping.getProperty();
          // Look at additional parameters (e.g. <foreach> items) first, as DefaultParameterHandler does
          if (boundSql.hasAdditionalParameter(propertyName)) {
            Object obj = boundSql.getAdditionalParameter(propertyName);
            sql = sql.replaceFirst("\\?", getParameterValue(obj));
          } else if (metaObject.hasGetter(propertyName)) {
            Object obj = metaObject.getValue(propertyName);
            sql = sql.replaceFirst("\\?", getParameterValue(obj));
          }
        }
      }
    }
    logs(time, sql, sqlId);
  }

  private static String getParameterValue(Object obj) {
    String value;
    if (obj instanceof String) {
      value = "'" + obj + "'";
    } else if (obj instanceof Date) {
      DateFormat formatter = DateFormat.getDateTimeInstance(DateFormat.DEFAULT, DateFormat.DEFAULT, Locale.CHINA);
      value = "'" + formatter.format(obj) + "'";
    } else if (obj != null) {
      value = obj.toString();
    } else {
      value = "";
    }
    // Escape both '$' and '\' so replaceFirst inserts the value literally
    return Matcher.quoteReplacement(value);
  }

  private static void logs(long time, String sql, String sqlId) {
    // System.lineSeparator() replaces StringPool.NEWLINE, which is not part of MyBatis
    String newline = System.lineSeparator();
    StringBuilder sb = new StringBuilder()
        .append(" Time:").append(time)
        .append(" ms - ID:").append(sqlId)
        .append(newline).append("Execute SQL:")
        .append(sql).append(newline);
    log.info(sb.toString());
  }

  @Override
  public Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  @Override
  public void setProperties(Properties properties) {
    // no configurable properties
  }
}
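printSql inlines each value with String.replaceFirst, and replaceFirst's replacement argument is not literal text. In it, $ starts a group reference and \ escapes the next character. A value such as 'price$1' therefore makes replaceFirst throw, because the pattern has no group 1. A backslash is silently swallowed, or triggers an exception when it ends the string. Matcher.quoteReplacement escapes both characters, whereas escaping $ alone leaves backslashes exposed. The standalone sketch below isolates this step in a hypothetical SqlInliner.inline helper. Like printSql, it is meant only for logging.

```java
import java.util.List;
import java.util.regex.Matcher;

final class SqlInliner {
  // Hypothetical helper: replaces each '?' placeholder, left to right, with the next value.
  // Logging only: a '?' inside a string literal or inside an already inlined value
  // would be picked up by a later replacement.
  static String inline(String sql, List<String> values) {
    String result = sql;
    for (String value : values) {
      // quoteReplacement escapes '$' and '\' so they are copied literally
      result = result.replaceFirst("\\?", Matcher.quoteReplacement(value));
    }
    return result;
  }
}
```

For example, inlining the values 'price$1' and 'C:\temp\' into "SELECT * FROM t WHERE a = ? AND b = ?" keeps both values intact. Passing the same strings to replaceFirst unescaped would fail.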