SpringBoot2.x中HandlerInterceptor

作者: 小胖学编程 | 来源:发表于2021-11-12 10:35 被阅读0次

Spring Boot 中使用 AspectJ 注解声明切面时,实际上会把 target 对象增强为 proxy(代理)对象;在 AspectJ 中可以通过 @Pointcut 声明切点。其实 HandlerInterceptor 也可以配合自定义注解使用,只对带有该注解的 Controller 方法进行增强(一个实用的小技巧)。

1. 原理

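HandlerInterceptorAdapter 的每个方法都有一个 handler 参数。当请求由 Controller 方法处理时,handler 是 HandlerMethod 的实例,可以强转:HandlerMethod handlerMethod = (HandlerMethod) handler;。强转前应先用 instanceof 判断,因为静态资源等请求的 handler 并不是 HandlerMethod。通过 handlerMethod.getMethod() 即可拿到对应的 Controller 方法,进而解析其上的注解,再根据注解值选择性地进行拦截。注意:自 Spring 5.3 起 HandlerInterceptorAdapter 已被标记为 @Deprecated,也可以直接实现 HandlerInterceptor / AsyncHandlerInterceptor 接口(其方法均提供 default 实现)。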

/**
 * Base class for interceptors that only care about some of the callbacks: every callback is a
 * no-op and {@code preHandle} always answers {@code true}. Subclasses override what they need.
 *
 * <p>NOTE(review): newer Spring versions deprecate this adapter in favor of implementing
 * HandlerInterceptor / AsyncHandlerInterceptor directly — confirm against the Spring version in use.
 */
public abstract class HandlerInterceptorAdapter implements AsyncHandlerInterceptor {

    /** Always returns {@code true}; performs no other work. */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
            Object handler) throws Exception {
        return true;
    }

    /** No-op. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response,
            Object handler, @Nullable ModelAndView modelAndView) throws Exception {
        // intentionally empty
    }

    /** No-op. */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
            Object handler, @Nullable Exception ex) throws Exception {
        // intentionally empty
    }

    /** No-op. */
    @Override
    public void afterConcurrentHandlingStarted(HttpServletRequest request,
            HttpServletResponse response, Object handler) throws Exception {
        // intentionally empty
    }
}

2. 实现

2.1 优化AnnotationUtils方法

拦截器需要通过 org.springframework.core.annotation.AnnotationUtils 解析 Controller 方法及类上的注解。这涉及反射,相对耗时,而且每个请求都会调用。AnnotationUtils 内部虽然对查找结果做了缓存,但(至少在部分 Spring 版本中)结果为 null 时不会被缓存。因此需要在外部再包一层能够缓存 null 值的缓存。

相关依赖:

<dependency>
  <groupId>com.google.guava</groupId>
  <artifactId>guava</artifactId>
  <version>31.0.1-jre</version>
</dependency>

将null值缓存起来:

import java.lang.annotation.Annotation;
import java.util.Optional;

import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.web.method.HandlerMethod;

import com.google.common.base.Objects;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

public class MethodAnnotationCacheUtils {
    /** Maximum number of (handler, annotation type) lookups kept in the cache. */
    private static final int MAX_CACHE_SIZE = 5000;
    /** Concurrency-level hint passed to Guava's CacheBuilder. */
    private static final int CACHE_CONCURRENCY_LEVEL = 200;

    /**
     * Cache of annotation lookups keyed by (handler, annotation type).
     *
     * <p>Values are wrapped in {@link Optional} so that "annotation not present" results are cached
     * too (as {@code Optional.empty()}), avoiding repeated reflective lookups for un-annotated handlers.
     *
     * <p>Keys are deliberately NOT weak: {@code CacheBuilder.weakKeys()} makes Guava compare keys by
     * identity ({@code ==}), and since {@link #getAnnotation} builds a new {@link AnnotationCacheKey}
     * on every call, such a cache would never produce a hit. Memory is bounded by {@link #MAX_CACHE_SIZE}.
     */
    private static final LoadingCache<AnnotationCacheKey, Optional<Annotation>> ANNOTATION_CACHE =
            CacheBuilder.newBuilder().maximumSize(MAX_CACHE_SIZE).concurrencyLevel(CACHE_CONCURRENCY_LEVEL)
                    .build(new CacheLoader<AnnotationCacheKey, Optional<Annotation>>() {
                        @Override
                        public Optional<Annotation> load(AnnotationCacheKey key) {
                            return Optional.ofNullable(
                                    getAnnotationInternal(key.getHandler(), key.getAnnotationClass()));
                        }
                    });


    /**
     * Returns the {@code annotationClass} annotation declared on the handler's method, falling back to
     * the handler's bean type when the method has none. Results, including misses, are cached.
     *
     * @param handler         the handler passed to a HandlerInterceptor callback
     * @param annotationClass the annotation type to look up
     * @return the annotation, or {@code null} if it is absent, if {@code handler} is not a
     *         {@link HandlerMethod}, or if {@code annotationClass} is {@code null}
     */
    public static <A extends Annotation> A getAnnotation(Object handler, Class<A> annotationClass) {
        if (annotationClass == null || !(handler instanceof HandlerMethod)) {
            return null;
        }
        AnnotationCacheKey cacheKey = new AnnotationCacheKey(handler, annotationClass);
        // Class.cast keeps this type-safe instead of an unchecked (A) cast.
        return ANNOTATION_CACHE.getUnchecked(cacheKey).map(annotationClass::cast).orElse(null);
    }


    /**
     * Uncached lookup: the annotation on the handler's method, otherwise on its bean type.
     *
     * @return the annotation, or {@code null} if not found or {@code handler} is not a {@link HandlerMethod}
     */
    private static <A extends Annotation> A getAnnotationInternal(Object handler, Class<A> annotationClass) {
        if (!(handler instanceof HandlerMethod)) {
            return null;
        }
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        // Method-level annotation wins over the class-level one.
        A result = AnnotationUtils.findAnnotation(handlerMethod.getMethod(), annotationClass);
        if (result == null) {
            result = AnnotationUtils.findAnnotation(handlerMethod.getBeanType(), annotationClass);
        }
        return result;
    }

    /**
     * Cache key combining a handler and an annotation type; equality is value-based
     * (delegates to the handler's and the class's {@code equals}).
     *
     * <p>NOTE(review): hits depend on per-request handler instances being {@code equals} to earlier
     * ones; if they are not (e.g. non-singleton controllers), entries will churn up to MAX_CACHE_SIZE.
     */
    private static final class AnnotationCacheKey {

        private final Object handler;

        private final Class<? extends Annotation> annotationClass;

        AnnotationCacheKey(Object handler, Class<? extends Annotation> annotationClass) {
            this.handler = handler;
            this.annotationClass = annotationClass;
        }

        Object getHandler() {
            return handler;
        }

        Class<? extends Annotation> getAnnotationClass() {
            return annotationClass;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            AnnotationCacheKey that = (AnnotationCacheKey) o;
            return Objects.equal(handler, that.handler) && Objects.equal(
                    annotationClass, that.annotationClass);
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(handler, annotationClass);
        }
    }
}

获取方法参数上 @RequestParam 注解声明的参数名:

public class MethodInterceptorUtils {

    /**
     * Returns the {@code annotationClass} annotation for the given interceptor handler.
     *
     * <p>Delegates to {@code MethodAnnotationCacheUtils.getAnnotation(handler, annotationClass)},
     * which caches lookups (including misses).
     *
     * @return the annotation, or {@code null} if absent or if {@code handler} is not a HandlerMethod
     */
    public static <A extends Annotation> A getAnnotation(Object handler, Class<A> annotationClass) {
        return MethodAnnotationCacheUtils.getAnnotation(handler, annotationClass);
    }

    /**
     * Returns the request-parameter names declared via {@code @RequestParam} on the handler method's
     * parameters.
     *
     * @param handler the handler passed to a HandlerInterceptor callback
     * @return the declared names; an empty stream (never {@code null}) when {@code handler} is not a
     *         HandlerMethod or no parameter carries a non-empty {@code @RequestParam} name
     */
    public static Stream<String> getRequestParams(Object handler) {
        if (!(handler instanceof HandlerMethod)) {
            // Empty rather than null so callers can Stream.concat(...) the result safely.
            return Stream.empty();
        }
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        return Stream.of(handlerMethod.getMethodParameters())
                .map(MethodInterceptorUtils::getParamName)
                .filter(Objects::nonNull);
    }


    /**
     * Returns the name declared by the parameter's {@code @RequestParam}, or {@code null} if the
     * parameter has no such annotation or declares no explicit name.
     */
    private static String getParamName(MethodParameter methodParameter) {
        RequestParam requestParam = methodParameter.getParameterAnnotation(RequestParam.class);
        if (requestParam == null) {
            return null;
        }
        // value and name are aliases; read both in case the annotation instance is not synthesized,
        // in which case only the attribute actually written in source is populated.
        String name = requestParam.value().isEmpty() ? requestParam.name() : requestParam.value();
        // A bare @RequestParam yields "", which is not a usable request-parameter name here.
        return name.isEmpty() ? null : name;
    }
}

2.2 个性化注解

创建个性化注解:

/**
 * Marks a controller method, or every handler method of a controller class, whose request URI
 * and selected request parameters are appended to the handling thread's name by the
 * thread-name interceptor.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface ThreadNamePattern {

    /**
     * Names of the request parameters whose values should be included.
     */
    String[] value() default {};

    /**
     * When {@code true}, only the parameters listed in {@link #value()} are used. When {@code false}
     * (the default), they are used in addition to the method's {@code @RequestParam} parameters.
     */
    boolean overrideDeclare() default false;
}

创建拦截器:

@Slf4j
@Service
public class ThreadNamePatternInterceptor extends HandlerInterceptorAdapter {

    /**
     * 忽略的参数
     */
    private Set<String> ignoreParamNames = ImmutableSet.of("password");

    /**
     * 线程开始,读取Controller方法中的注解信息,在拦截器中进行增强。
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        ThreadNamePattern threadNamePattern = MethodInterceptorUtils.getAnnotation(handler, ThreadNamePattern.class);
        setThreadName(handler, threadNamePattern, request);
        return super.preHandle(request, response, handler);
    }

    private void setThreadName(Object handler, ThreadNamePattern threadNamePattern,
            HttpServletRequest request) {
        try {
            //获取当前线程
            Thread currentThread = Thread.currentThread();
            String originalThreadName = currentThread.getName();

            if (!originalThreadName.endsWith("}")) {
                StringBuffer sb = new StringBuffer(request.getRequestURI());
                Stream<String> paramsNames;
                if (threadNamePattern.overrideDeclare()) {
                    //只输出value的参数
                    paramsNames = Stream.of(threadNamePattern.value());
                } else {
                    //输出所有的@RequestParam+value的参数
                    paramsNames = Stream.concat(Stream.of(threadNamePattern.value()),
                            MethodInterceptorUtils.getRequestParams(handler));
                }
                //拼接参数
                String param = paramsNames.distinct().filter(name -> !ignoreParamNames.contains(name))
                        .filter(name -> request.getParameter(name) != null)
                        .map(name -> name + "=" + request.getParameter(name)).collect(
                                Collectors.joining("&"));
                if (!param.isEmpty()) {
                    sb.append("?").append(param);
                }
                //拼接线程信息
                String newThreadName = originalThreadName + "-{" + sb + "}";
                currentThread.setName(newThreadName);
            }
        } catch (Exception e) {
            log.error("", e);
        }
    }


    /**
     * 线程结束,清除个性化线程信息。
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        try {
            super.afterCompletion(request, response, handler, ex);
        } finally {
            ThreadNamePattern threadNamePattern =
                    MethodInterceptorUtils.getAnnotation(handler, ThreadNamePattern.class);
            if (threadNamePattern != null) {
                tryRestoreThreadName();
            }
        }
    }

    private void tryRestoreThreadName() {
        try {
            Thread currentThread = currentThread();
            String originalThreadName = currentThread.getName();
            if (originalThreadName.endsWith("}")) {
                String newThreadName = originalThreadName.substring(0,
                        originalThreadName.lastIndexOf("-{"));
                currentThread.setName(newThreadName);
            }
        } catch (Throwable e) {
            log.error("", e);
        }
    }
}

3. 测试方法

/**
 * Demo endpoints for {@link ThreadNamePattern}: both take "id" and "name" request parameters
 * and echo them back.
 */
@Slf4j
@RestController
public class FirstController {

    /** Annotated with overrideDeclare = true: only "id" is listed for the thread name. */
    @RequestMapping("/test")
    @ThreadNamePattern(value = "id", overrideDeclare = true)
    public String test(@RequestParam("id") Long id, @RequestParam("name") String name) {
        log.info("test,请求进来了");
        return describe(id, name);
    }

    /** Annotated with overrideDeclare = false: "id" plus all @RequestParam names are listed. */
    @RequestMapping("/t1")
    @ThreadNamePattern(value = "id", overrideDeclare = false)
    public String t1(@RequestParam("id") Long id, @RequestParam("name") String name) {
        log.info("t1,请求进来了");
        return describe(id, name);
    }

    /** Builds the response body {@code <id>-<name> is success}. */
    private static String describe(Long id, String name) {
        return new StringBuilder()
                .append(id)
                .append("-")
                .append(name)
                .append(" is success")
                .toString();
    }
}

相关文章

网友评论

    本文标题:SpringBoot2.x中HandlerInterceptor

    本文链接:https://www.haomeiwen.com/subject/bdibzltx.html