美文网首页
Spring 实现一个支持 分发 转发的 rest接口

Spring 实现一个支持 分发 转发的 rest接口

作者: Yellowtail | 来源:发表于2019-02-16 10:03 被阅读0次

    需求概述

    我们的需求是这样的:
    后台的rest接口对接 安卓、IOS客户端、小程序
    考虑到安全,我们对请求参数做了签名校验,对部分接口做了登录态校验
    安全是安全了,但是 对于我们后台开发人员来说,进行接口测试就变得困难起来

    接口测试,我们前面是这么做的

    阶段一

    在本地进行接口测试,把关键的Filter 注释掉
    感受:太麻烦了,不小心就会把注释的代码提交了(我用的是乌龟git)

    阶段二

    我用Python写了一个本地代理,postman 配置好代理,让请求走我的代理
    这个代理会模拟客户端,对参数进行签名
    工程在这里:proxy
    感受:我们有些接口使用到了Protobuf,因为传输的是二进制数据流,接口返回的数据,代理想转为可视化的json数据很费劲
    (目前是先用protobuf生成Python文件,再反序列化,再显示,生成文件这一步需要经常维护)
    而且我们的登录态用到了jwttoken动不动就过期了,需要更换;虽然使用postman的全局变量可以减少更换次数,但是如果想切换用户查看接口返回结果,又需要去数据库找到对应的token,再替换,有点费劲

    经历过这两个阶段,后面就想着,能不能开发出这样一个接口呢:
    参数里写上url method userId
    然后这个接口根据这些参数,进行分发(转发),然后在Filter里面排除这个接口,不就避免了签名校验 登陆校验吗?

    如何实现

    刚开始实现的时候,也是一脸懵逼,用“转发” 关键字 搜了一番,发现都不能用在Rest 接口
    最后没办法,想着去了解一下spring是怎么实现的,我再重复实现一遍不就OK了吗
    于是去看了看 DispatcherServlet 源码解析文章,找到了思路

    先看看 DispatcherServlet 的 核心方法 doDispatch

    protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
            HttpServletRequest processedRequest = request;
            HandlerExecutionChain mappedHandler = null;
            boolean multipartRequestParsed = false;
    
            WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
    
            try {
                ModelAndView mv = null;
                Exception dispatchException = null;
    
                try {
                    processedRequest = checkMultipart(request);
                    multipartRequestParsed = (processedRequest != request);
    
                    // Determine handler for the current request.
                                    // 找到当前request对应的 handler
                    mappedHandler = getHandler(processedRequest);
                    if (mappedHandler == null || mappedHandler.getHandler() == null) {
                        noHandlerFound(processedRequest, response);
                        return;
                    }
    
                    // Determine handler adapter for the current request.
                                    // 找到当前 handler 对应的 适配器
                    HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());
    
                    // Process last-modified header, if supported by the handler.
                    String method = request.getMethod();
                    boolean isGet = "GET".equals(method);
                    if (isGet || "HEAD".equals(method)) {
                        long lastModified = ha.getLastModified(request, mappedHandler.getHandler());
                        if (logger.isDebugEnabled()) {
                            logger.debug("Last-Modified value for [" + getRequestUri(request) + "] is: " + lastModified);
                        }
                        if (new ServletWebRequest(request, response).checkNotModified(lastModified) && isGet) {
                            return;
                        }
                    }
    
                                    //执行 HandlerExecutionChain 里面 拦截器的 preHandle
                    if (!mappedHandler.applyPreHandle(processedRequest, response)) {
                        return;
                    }
    
                    // Actually invoke the handler.
                                    // 让适配器运行handler,也就是执行 Controller里的某个具体的方法
                    mv = ha.handle(processedRequest, response, mappedHandler.getHandler());
    
                    if (asyncManager.isConcurrentHandlingStarted()) {
                        return;
                    }
    
                    applyDefaultViewName(processedRequest, mv);
    
                                    //执行 HandlerExecutionChain 里面 拦截器的 postHandle
                    mappedHandler.applyPostHandle(processedRequest, response, mv);
                }
                catch (Exception ex) {
                    dispatchException = ex;
                }
                catch (Throwable err) {
                    // As of 4.3, we're processing Errors thrown from handler methods as well,
                    // making them available for @ExceptionHandler methods and other scenarios.
                    dispatchException = new NestedServletException("Handler dispatch failed", err);
                }
                processDispatchResult(processedRequest, response, mappedHandler, mv, dispatchException);
            }
            catch (Exception ex) {
                triggerAfterCompletion(processedRequest, response, mappedHandler, ex);
            }
            catch (Throwable err) {
                triggerAfterCompletion(processedRequest, response, mappedHandler,
                        new NestedServletException("Handler processing failed", err));
            }
            finally {
                if (asyncManager.isConcurrentHandlingStarted()) {
                    // Instead of postHandle and afterCompletion
                    if (mappedHandler != null) {
                        mappedHandler.applyAfterConcurrentHandlingStarted(processedRequest, response);
                    }
                }
                else {
                    // Clean up any resources used by a multipart request.
                    if (multipartRequestParsed) {
                        cleanupMultipart(processedRequest);
                    }
                }
            }
        }
    

    关键的地方我都写了中文注释

    所以 我们自己实现 分发 效果的时候,参考这个逻辑即可
    找到handler --> 找到适配器 --> 执行拦截器 preHandle(如果需要) --> 执行 handler --> 执行拦截器 postHandle(如果需要)

    实现代码

    下面是 Controller

    package com.xxx.app.skmr.controller;
    
    import java.util.HashMap;
    
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletResponse;
    
    import org.apache.commons.lang3.StringUtils;
    import org.powermock.reflect.Whitebox;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.http.HttpHeaders;
    import org.springframework.http.MediaType;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RequestMethod;
    import org.springframework.web.bind.annotation.RequestParam;
    import org.springframework.web.bind.annotation.RestController;
    import org.springframework.web.method.HandlerMethod;
    import org.springframework.web.servlet.DispatcherServlet;
    import org.springframework.web.servlet.HandlerAdapter;
    import org.springframework.web.servlet.HandlerExecutionChain;
    import org.springframework.web.servlet.ModelAndView;
    
    import com.xxx.app.skmr.constant.Constants;
    import com.xxx.app.skmr.filter.EncryptRequest;
    import com.xxx.app.skmr.properties.CommonConfigProperties;
    import com.xxx.app.skmr.service.ReqCacheService;
    import com.xxx.app.skmr.util.AssertHelper;
    
    @RequestMapping("/v1")
    @RestController
    public class InterfaceTestController {
        
        private static final Logger LOGGER = LoggerFactory.getLogger(InterfaceTestController.class);
        
        @Autowired
        private DispatcherServlet dispatcherServlet;
        
        @Autowired
        private CommonConfigProperties commonConfigProperties;
    
        @RequestMapping(value = "/iTest", method = RequestMethod.POST)
        public void receiveRequest(
                @RequestParam(value="url", required=true) String url,
                @RequestParam(value="method", required=true) String method,
                @RequestParam(value="param", required=false) String param,
                @RequestParam(value="userId", required=false) String userId,
                HttpServletRequest request, 
                HttpServletResponse response) {
            
            LOGGER.info("InterfaceTestController receiveRequest, url={}", url);
            
            if (! Constants.SKMR_APP_ENV_NAME_DEV.equals(commonConfigProperties.getEnvName())) {
                //不是 dev 环境,直接退出
                return;
            }
            
            //设置登录态
            if (StringUtils.isNotBlank(userId)) {
                ReqCacheService.setReqUserId(request, userId);
            }
            
            //改变request里的值
            EncryptRequest myRequest = (EncryptRequest) request;
            
            //强行设置 header 里的 accept 为 application/json,为客户端省点事
            myRequest.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
            
            myRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
            
            //修改方法,方法要大写
            myRequest.setMethod(method.toUpperCase());
            
            //修改 请求path
            myRequest.setRequestURI(url);
            
            //这个必须要设置,不然在 UrlPathHelper.getLookupPathForRequest()  方法里面 有问题,需要让rest 为空
            myRequest.setServletPath(url);
            
            //设置url 查询参数,
            if (StringUtils.isNotBlank(param)) {
                HashMap<String, String> convertParam = convertParam(param);
                
                myRequest.setUseExternalParam(true);
                myRequest.setParamMap(convertParam);
            }
            
            //转发
            redirect(myRequest, response);
            
            return ;
        }
        
        @RequestMapping(value = "/iTest", method = RequestMethod.GET)
        public void receiveRequestV2(
                @RequestParam(value="url", required=true) String url,
                @RequestParam(value="method", required=true) String method,
                @RequestParam(value="param", required=false) String param,
                @RequestParam(value="userId", required=false) String userId,
                HttpServletRequest request, 
                HttpServletResponse response) {
            
            //这个接口因为是 get 的,所以可以通过浏览器直接调用,不需要postman
            //当然了,对于有body的,浏览器就不行了,还是需要postman
            
            LOGGER.info("InterfaceTestController receiveRequestV2, url={}", url);
            
            receiveRequest(url, method, param, userId, request, response);
            
            return ;
        }
        
        /**
         * <br>转发 分发
         * <br>只有自身的url 会走 过滤器,转发后的 没有走过滤器
         *
         * @param request
         * @param response
         * @author YellowTail
         * @since 2019-02-15
         */
        private void redirect(HttpServletRequest request, HttpServletResponse response) {
            
            try {
                // 1. 得到 HandlerExecutionChain, 调用方法 getHandler 即可得到
                HandlerExecutionChain handlerExecutionChain = Whitebox.invokeMethod(dispatcherServlet, "getHandler", request);
                
                // 2. 取出 HandlerMethod,适配器要用
                HandlerMethod handlerMethod = (HandlerMethod) handlerExecutionChain.getHandler();
                
                // 3. 得到 适配器 HandlerAdapter,调用方法 getHandlerAdapter 得到
                HandlerAdapter ha = Whitebox.invokeMethod(dispatcherServlet, "getHandlerAdapter", handlerMethod);
                
                // 4. 执行 HandlerExecutionChain 拦截器的 preHandler() 前置方法, CmdHandlerInterceptor 会去设置 cmd
                Whitebox.invokeMethod(handlerExecutionChain, "applyPreHandle", request, response);
                
                // 5. 执行 handler
                ModelAndView mv = ha.handle(request, response, handlerMethod);
                
                // 6. 执行 拦截器的 postHandler() 方法, LogHandlerInterceptor 会去记录日志
                Whitebox.invokeMethod(handlerExecutionChain, "applyPostHandle", request, response, mv);
                
            } catch (Exception e) {
                LOGGER.error("error, ", e);
            }
        }
        
        /**
         * <br>将字符串形式的 参数  id=test&type=2 转换为 map,方便使用
         *
         * @param param
         * @return
         * @author YellowTail
         * @since 2019-02-15
         */
        private HashMap<String, String> convertParam(String param) {
            HashMap<String, String> map = new HashMap<>();
            
            if (StringUtils.isBlank(param)) {
                return map;
            }
            
            String[] split = param.split("&");
            
            for(String eachParam : split) {
                String[] split2 = eachParam.split("=");
                
                AssertHelper.assertTrue(2 >= split2.length , eachParam + " eachParam should contains one =");
                
                if (2 == split2.length) {
                    map.put(split2[0], split2[1]);
                }
            }
            
            return map;
        }
    }
    
    

    因为我们需要对 request 做很多操作,所以必须自己实现一个request ,且继承 HttpServletRequestWrapper

    package com.xxx.app.skmr.filter;
    
    import java.io.BufferedReader;
    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.StringReader;
    import java.util.Enumeration;
    import java.util.HashMap;
    import java.util.Vector;
    
    import javax.servlet.ReadListener;
    import javax.servlet.ServletInputStream;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletRequestWrapper;
    
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import com.xxx.app.skmr.util.IOSystem;
    import com.xxx.app.skmr.util.StringUtils;
    
    
    public class EncryptRequest extends HttpServletRequestWrapper {
        
        private static final Logger LOGGER = LoggerFactory.getLogger(EncryptRequest.class);
        
        /**
         * URL 的方法
         */
        private String _method;
        
        /**
         * URI
         */
        private String _requestURI;
        
        /**
         * ServletPath
         */
        private String _servletPath;
        
        
        /**
         * 是否使用 扩展的 参数
         */
        private boolean useExternalParam = false;
        
        /**
         * 参数 map
         */
        private HashMap<String, String> paramMap;
        
        /**
         * header
         */
        private HashMap<String, String> _headers;
        
        /**
         * 存储requestBody的内容
         */
        private byte[] requestBody = null;
        
    
       public EncryptRequest(HttpServletRequest request) {
           super(request);
      
           try {
               this.requestBody = IOSystem.readToBytes(request.getInputStream());
           } catch (IOException e) {
               e.printStackTrace();
               LOGGER.error("EncryptRequest init getInputStream error", e);
               throw new RuntimeException(e);
           }
       }
    
        /**
         * 获取requestbody
         */
        public byte[] getRequestBody() {
            return this.requestBody;
        }
        
        public String getRequestBodyString() {
            return StringUtils.toString(requestBody, this.getRequest().getCharacterEncoding());
        }
        
        public void setRequestBody(byte[] requestBody ) {
            this.requestBody = requestBody;
        }
    
    
        @Override
        public ServletInputStream getInputStream() throws IOException {
            // 如果是null 证明首次数据获取失败
            if (requestBody == null) {
                requestBody = new byte[0];
            }
            
            final ByteArrayInputStream bais = new ByteArrayInputStream(requestBody);
            return new ServletInputStream() {
                @Override
                public int read() throws IOException {
                    return bais.read();
                }
    
                @Override
                public boolean isFinished() {
                    return false;
                }
    
                @Override
                public boolean isReady() {
                    return true;
                }
    
                @Override
                public void setReadListener(ReadListener listener) {
                }
            };
        }
    
    
        @Override
        public BufferedReader getReader() throws IOException {
            return new BufferedReader(new StringReader(this.getRequestBodyString()));
        }
        
        /**
         * 重写 getMethod,方便 变更 method
         */
        @Override
        public String getMethod() {
            if (null == _method) {
                _method = super.getMethod();
            }
            return _method;
        }
        
        /**
         * <br>扩展方法:支持修改  method
         * <br>在 接口测试的接口里使用到了
         *
         * @param newMethod
         * @author YellowTail
         * @since 2019-02-14
         */
        public void setMethod(String newMethod) {
            _method = newMethod;
        }
        
        /**
         * 覆盖,方便变更
         */
        @Override
        public String getRequestURI() {
            if (null == _requestURI) {
                _requestURI = super.getRequestURI();
            }
            return _requestURI;
        }
        
        /**
         * <br>扩展方法,支持变更 path
         * <br>在 接口测试的接口里使用到了
         *
         * @param value
         * @author YellowTail
         * @since 2019-02-14
         */
        public void setRequestURI(String value) {
            _requestURI = value;
        }
        
        @Override
        public String getServletPath() {
            if (null == _servletPath) {
                _servletPath = super.getServletPath();
            }
            return _servletPath;
        }
        
        public void setServletPath(String value) {
            _servletPath = value;
        }
        
        @Override
        public String[] getParameterValues(String name) {
            if (useExternalParam) {
                String string = paramMap.get(name);
                if (null == string) {
                    return null;
                }
                
                return new  String [] {string};
            }
            return super.getParameterValues(name);
        }
    
        public void setUseExternalParam(boolean useExternalParam) {
            this.useExternalParam = useExternalParam;
        }
    
        public void setParamMap(HashMap<String, String> paramMap) {
            this.paramMap = paramMap;
        }
        
        
        /**
         * 覆盖 header 获取的方法,进行自定义扩展
         */
        @Override
        public Enumeration<String> getHeaders(String name) {
            
            //如果某个值被设置过,那么用自定义的值
            if (null != _headers && _headers.containsKey(name)) {
                String string = _headers.get(name);
                
                Vector<String> values = new Vector<String>();
                values.add(string);
                
                return values.elements();
            }
            
            return super.getHeaders(name);
        }
        
        /**
         * <br>设置 Header 里的值
         *
         * @param key
         * @param value
         * @author YellowTail
         * @since 2019-02-15
         */
        public void setHeader(String key, String value) {
            if (null == _headers) {
                _headers = new HashMap<>();
            }
            
            _headers.put(key, value);
        }
       
    }
    

    实现代码解释

    因为 DispatcherServlet 的很多方法都是 protected, friendly 懒的搞继承,直接反射调用了
    关于拦截器,因为我在实现的时候,是需要拦截器的一些效果(自动上传日志),所以就执行了步骤 4 和 6
    大家在实现的时候,可以根据实际情况进行取舍

    为何自定义request
    因为 HttpServletRequest 只有一堆的 get 方法,没有 set 方法
    看了下实现,反射好费劲,算了,直接继承一个,复写方法

    效果

    Type value
    接口地址 /v1/iTest
    接口方法 Post
    接口Header header里的Accept Content-Type 都不需要设置,代码已经写死为application/json
    设置为其他值不会生效
    接口参数 是否必填 解释
    url 必填 准备请求哪个接口
    method 必填 接口的方法(因为有些接口url一样,但是method不一样),
    大小写不敏感, get Get GET 都行
    userId 非必填 设置登录态,即想用哪个用户请求接口
    param 非必填 请求接口的请求参数,比如对于接口 /v1/xxx/me/list?unitId=2&nextid=&scope=2
    那么param就是?后面的字符串,且需要进行url编码
    也就是unitId%3D2%26nextid%3D%26scope%3D2

    BUG 修复 2019年2月20日 10:31:35

    修复了 param 里面 参数值为空抛异常的问题
    代码已在此博客里更新

    功能新增:把异常输出到浏览器上,省去看日志的步骤

    一旦代码抛了异常,浏览器调用接口的时候,看不到信息
    于是突发奇想,把 Exception 信息 参考 Logger 那种方式输出到屏幕上,多方便
    于是写了一下,代码在下面,没有更新到 文章开始的那个代码块里

    public static final String CHANGE_LINE = "\n".intern();
    public static final String TAB_INDENT = "    at ";
    
    ...
    catch (Exception e) {
                LOGGER.error("error, ", e);
                
                StringBuilder sb = new StringBuilder();
                
                sb.append(e.toString()).append(CHANGE_LINE);
                
                for(StackTraceElement st: e.getStackTrace()) {
                    sb.append(TAB_INDENT)
                        .append(st.toString())
                        .append(CHANGE_LINE);
                }
                
                byte[] bytes = sb.toString().getBytes();
                HttpEncryptService.setResponseData(bytes, response, false);
            }
    

    相关文章

      网友评论

          本文标题:Spring 实现一个支持 分发 转发的 rest接口

          本文链接:https://www.haomeiwen.com/subject/zaqbeqtx.html