美文网首页
【SpringBoot】2022-03-26【自定义请求转发、分

【SpringBoot】2022-03-26【自定义请求转发、分

作者: SweetMojito | 来源:发表于2022-03-26 11:24 被阅读0次

需求背景:

当后端部署在多个区域(或有多个实例)而前端只有一套界面时,通常需要通过前端的区域筛选器切换访问对应区域的后端实例。一种做法是前端直接访问不同区域后端的IP,但这样在新增区域或后端地址变化时不够灵活;另一种做法是让前端统一访问一个类似注册中心的转发服务:转发服务根据请求中的区域字段获取对应区域后端的地址,把请求转发过去,拿到接口返回数据后再返回给前端。这样前端只需要配置一个访问地址,即转发服务的地址。

各区域的后端实例可以在启动时,将本机服务信息注册到zk中,这样转发服务就可以从zk中获取提供服务的后端地址。

主要流程:

  1. 前端的所有请求,都请求到转发服务

  2. 转发服务根据请求头中的区域进行转发

  3. 各区域提供服务的IP和端口从zk中获取,zk扮演注册中心的角色

  4. 各区域后端启动的时候往zk中注册服务

主要代码:

对request对象进行包装:CustomHttpServletRequestWrapper.class

package com.tiger.web.common;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * Wraps an {@link HttpServletRequest} and buffers its body in memory so that the body can be
 * read more than once (the raw servlet input stream can only be consumed a single time).
 *
 * <p>The complete body is held in a byte array, so this wrapper should only be used for requests
 * whose size the caller is willing to keep in memory.
 *
 * @author tiger 2022/3/18
 */
public class CustomHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /** Chunk size used while copying the body; deliberately independent of Content-Length. */
    private static final int COPY_BUFFER_SIZE = 4096;

    /** The buffered request body; an empty array (never null) when the request has no body. */
    private final byte[] body;

    /**
     * Reads and buffers the entire body of {@code request}.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public CustomHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.body = readFully(request.getInputStream());
    }

    /**
     * Copies {@code in} into a byte array until end of stream.
     *
     * <p>A fixed-size chunk is used instead of one sized from the Content-Length header: that
     * header is missing for GET and chunked requests (parsing it threw NumberFormatException), and
     * a zero-length buffer makes {@code read(byte[])} return 0 forever, which never terminates.
     */
    private static byte[] readFully(InputStream in) throws IOException {
        try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            byte[] chunk = new byte[COPY_BUFFER_SIZE];
            int read;
            while ((read = in.read(chunk)) != -1) {
                out.write(chunk, 0, read);
            }
            return out.toByteArray();
        }
    }

    /**
     * Returns a fresh stream over the buffered body; every call starts again at the first byte.
     */
    @Override
    public ServletInputStream getInputStream() {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // bulk read avoids one virtual call per byte for large bodies
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // the data is already in memory, so a read never blocks
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Non-blocking (async) reads are not supported; the listener is ignored.
            }
        };
    }

    /**
     * Returns a reader over the buffered body. Without this override the call would be delegated
     * to the original request, whose stream has already been consumed by the constructor.
     *
     * <p>Uses the request's declared character encoding, falling back to UTF-8 (the encoding of
     * JSON bodies) when none is declared.
     *
     * @throws UnsupportedEncodingException if the declared encoding is not supported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        Charset charset;
        if (encoding == null) {
            charset = StandardCharsets.UTF_8;
        } else {
            try {
                charset = Charset.forName(encoding);
            } catch (IllegalArgumentException e) {
                UnsupportedEncodingException unsupported = new UnsupportedEncodingException(encoding);
                unsupported.initCause(e);
                throw unsupported;
            }
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }
}

过滤器,拦截所有前端请求并进行转发,DispatcherFilter.class

package com.tiger.web.common;

import com.alibaba.fastjson.JSON;
import com.tiger.common.constant.HttpStatus;
import com.tiger.common.domain.OResult;
import com.tiger.web.service.DispatcherService;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.I0Itec.zkclient.ZkClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;

/**
 * Servlet filter that forwards every front-end request to one backend instance of the zone named
 * by the {@code zoneCode} request header.
 *
 * <p>Backend addresses are the child nodes registered in ZooKeeper under
 * {@code <registry path>/<zone>}; one of them is picked at random for each request. The filter
 * chain is never continued: every request is answered by this filter.
 *
 * <p>NOTE(review): the class is both a Spring {@code @Component} and a {@code @WebFilter}; if
 * {@code @ServletComponentScan} is also enabled it may be registered twice — confirm.
 *
 * @author tiger 2022/3/26
 */
@Component
@WebFilter(urlPatterns = {"/*"}, filterName = "authFilter")
@NoArgsConstructor
@Slf4j
public class DispatcherFilter implements Filter {
    /** ZooKeeper path under which backends of the default project register. */
    private static final String REGISTRY_SERVER_PATH = GlobalConstants.ZK_NODE_SEPARATOR + GlobalConstants.DEFAULT_PROJECT + "/server";
    /** ZooKeeper path under which backends of the "other" project register. */
    private static final String OTHER_REGISTRY_SERVER_PATH = GlobalConstants.ZK_NODE_SEPARATOR + GlobalConstants.OTHER_PROJECT + "/server";
    /** Browser icon requests are swallowed instead of being forwarded. */
    private static final String FAVICON_URI = "/favicon.ico";

    /** Request URIs that are always forwarded to a backend of {@link #specialZoneCode}. */
    @Value("${dispatcher.special.url}")
    List<String> dispatcherSpecialUrlList;
    /** Zone used for the special URIs; also substituted when the request's zone code is "all". */
    @Value("${dispatcher.special.zoneCode}")
    String specialZoneCode;
    @Resource
    ZkClient zkClient;
    @Resource
    DispatcherService dispatcherService;

    @Override
    public void init(FilterConfig filterConfig) {
        // nothing to initialise; collaborators are injected by Spring
    }

    /**
     * Buffers the request body, runs the auth hook, chooses a backend and forwards the request.
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        CustomHttpServletRequestWrapper requestWrapper = new CustomHttpServletRequestWrapper((HttpServletRequest) req);
        HttpServletResponse response = (HttpServletResponse) res;
        response.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE,PUT");
        response.setHeader("Access-Control-Max-Age", "3600");
        response.setHeader("Access-Control-Allow-Headers", "x-requested-with");

        response.setCharacterEncoding("utf-8");
        response.setContentType("application/json;charset=UTF-8");

        // user check
        if (authCheck(requestWrapper, response)) {
            return;
        }

        // favicon requests are answered with an empty response and never forwarded
        if (FAVICON_URI.equals(requestWrapper.getRequestURI())) {
            return;
        }

        Dispatcher dispatcher = getDispatcherUrl(requestWrapper, response);
        if (Boolean.TRUE.equals(dispatcher.getIsReturn())) {
            return;
        }

        String forward = "http://" + dispatcher.getTarget() + requestWrapper.getRequestURI();
        log.info("========== accept request {} {}", requestWrapper.getMethod(), forward);
        dispatcherService.requestService(requestWrapper, response, HttpMethod.valueOf(requestWrapper.getMethod()), forward);
        log.info("========== finished request {} {}", requestWrapper.getMethod(), forward);
    }

    /**
     * Chooses the backend ({@code ip:port}) the request is forwarded to.
     *
     * <p>URIs containing {@code /other} use the "other" project's registry path. Special URIs go
     * to {@link #specialZoneCode}; all others go to the zone in the {@code zoneCode} header. When
     * no backend can be chosen, an error body is written and {@code isReturn} stays {@code true}.
     *
     * @param myRequestWrapper the buffered incoming request
     * @param response         the response an error body is written to
     * @return the routing decision; {@code target} is only set when {@code isReturn} is false
     * @throws IOException if writing the error response fails
     */
    private Dispatcher getDispatcherUrl(HttpServletRequest myRequestWrapper, HttpServletResponse response) throws IOException {
        Dispatcher result = new Dispatcher();
        result.setIsReturn(true);

        String requestUri = myRequestWrapper.getRequestURI();
        String zoneCode = myRequestWrapper.getHeader("zoneCode");
        // the "all" zone code is served by the special zone
        if ("all".equals(zoneCode)) {
            zoneCode = specialZoneCode;
        }

        String zkChroot = requestUri.contains("/other") ? OTHER_REGISTRY_SERVER_PATH : REGISTRY_SERVER_PATH;
        if (!zkClient.exists(zkChroot)) {
            // No backend has ever registered for this project; getChildren on a missing path
            // presumably throws, which previously surfaced as an unhandled filter exception.
            writeFailure(response, String.format("不支持该机房%s", zoneCode));
            return result;
        }
        List<String> supportZoneList = zkClient.getChildren(zkChroot);
        log.info("获取注册中心支持的机房: {}", JSON.toJSONString(supportZoneList));

        if (!supportZoneList.contains(zoneCode)) {
            writeFailure(response, String.format("不支持该机房%s", zoneCode));
            return result;
        }

        // Decide the zone first so only that zone's nodes are fetched; a special URI no longer
        // fails just because the caller's own zone currently has no nodes.
        boolean specialUrl = dispatcherSpecialUrlList.contains(requestUri);
        if (specialUrl) {
            log.info("Special url");
        }
        String targetZone = specialUrl ? specialZoneCode : zoneCode;

        List<String> nodes = supportZoneList.contains(targetZone)
                ? zkClient.getChildren(zkChroot + GlobalConstants.ZK_NODE_SEPARATOR + targetZone)
                : null;
        if (nodes == null || nodes.isEmpty()) {
            // report the zone that was actually queried (previously always the header's zone)
            writeFailure(response, String.format("当前机房(%s)无可用服务", targetZone));
            return result;
        }

        // pick a random instance as a simple load-balancing strategy
        String target = nodes.get((int) (Math.random() * nodes.size()));
        result.setIsReturn(false);
        result.setTarget(target);
        return result;
    }

    /**
     * Writes the JSON failure body used when a request cannot be routed.
     *
     * <p>NOTE(review): the HTTP status is 403 while the body carries 401. This is kept unchanged
     * for client compatibility; confirm which one is intended.
     */
    private static void writeFailure(HttpServletResponse response, String message) throws IOException {
        response.setStatus(HttpStatus.FORBIDDEN);
        response.getWriter().write(JSON.toJSONString(OResult.fail(HttpStatus.UNAUTHORIZED, message)));
    }

    /**
     * Hook for user authentication and similar checks.
     *
     * <p>Placeholder implementation: always lets the request through. A real check that rejects
     * a request should write the error response itself and return {@code true}.
     *
     * @param myRequestWrapper the buffered incoming request
     * @param response         the response an error body may be written to
     * @return {@code true} if the request has been answered and must not be forwarded
     * @throws IOException if writing the error response fails
     */
    private Boolean authCheck(HttpServletRequest myRequestWrapper, HttpServletResponse response) throws IOException {
        // Returning true here would end every request with an empty response and forward nothing.
        return false;
    }

    /** Routing decision: whether the filter has already answered, and the chosen backend. */
    @Data
    static class Dispatcher {
        /** True when a response has already been written and nothing should be forwarded. */
        private Boolean isReturn;
        /** The chosen backend as {@code ip:port}. */
        private String target;
    }

    @Override
    public void destroy() {
        // no resources to release
    }
}

具体的转发实现,DispatcherService.class

package com.tiger.web.service;

import com.alibaba.fastjson.JSON;
import com.google.common.collect.Lists;
import com.tiger.web.common.CustomHttpServletRequestWrapper;
import com.tiger.common.constant.HttpStatus;
import com.tiger.common.domain.OResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.tomcat.util.http.fileupload.FileItem;
import org.apache.tomcat.util.http.fileupload.RequestContext;
import org.apache.tomcat.util.http.fileupload.disk.DiskFileItemFactory;
import org.apache.tomcat.util.http.fileupload.servlet.ServletFileUpload;
import org.apache.tomcat.util.http.fileupload.servlet.ServletRequestContext;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.List;

/**
 * Forwards a servlet request to a backend URL with {@link RestTemplate} and copies the backend's
 * status and body back into the servlet response.
 *
 * <p>Multipart (file upload) requests are parsed and re-encoded part by part. All other requests
 * are forwarded with their raw body bytes.
 *
 * @author tiger 2022/3/18
 */
@Component
@Slf4j
public class DispatcherService {
    @Resource
    RestTemplate restTemplate;
    /** Upper bound for the size of a whole multipart request (10 GiB). */
    private static final long MAX_SIZE = 10 * 1024 * 1024 * 1024L;
    /** Parts larger than this many bytes are spilled to {@link #UPLOAD_TEMP_DIR}. */
    private static final int UPLOAD_SIZE_THRESHOLD = 4096;
    /** Directory for spilled upload parts; created on demand. */
    private static final String UPLOAD_TEMP_DIR = "./uploadFileTemp";

    /**
     * Incoming headers that are not copied to the forwarded request.
     * <ul>
     *   <li>Host names this proxy, not the backend.</li>
     *   <li>Content-Length, Transfer-Encoding and Connection describe the incoming hop's framing.</li>
     *   <li>Accept-Encoding might make the backend send a compressed body that is then parsed as
     *       plain text.</li>
     * </ul>
     */
    private static final String[] NON_FORWARDED_HEADERS = {
            HttpHeaders.HOST, HttpHeaders.CONTENT_LENGTH, HttpHeaders.TRANSFER_ENCODING,
            HttpHeaders.CONNECTION, HttpHeaders.ACCEPT_ENCODING};

    /**
     * Forwards {@code request} to {@code uri} and writes the backend's answer to {@code response}.
     *
     * @param request  the (buffered) incoming request
     * @param response the response to fill
     * @param method   the HTTP method to use towards the backend
     * @param uri      absolute backend URL without query string; the incoming query is appended
     * @throws IOException if reading the request or writing the response fails
     */
    public void requestService(HttpServletRequest request, HttpServletResponse response, HttpMethod method, String uri) throws IOException {
        if (ServletFileUpload.isMultipartContent(request)) {
            // file uploads have to be parsed and re-encoded as multipart
            doFileUploadDispatch(request, response, method, uri);
        } else {
            doDispatch(request, response, method, uri);
        }
    }

    /**
     * Forwards a multipart request: every part is parsed and sent again under its own field name.
     */
    private void doFileUploadDispatch(HttpServletRequest request, HttpServletResponse response, HttpMethod method, String uri) throws IOException {
        log.info("doFileUploadDispatch");
        File repository = new File(UPLOAD_TEMP_DIR);
        // parts above the threshold are written here; a missing directory would fail the upload
        if (!repository.isDirectory() && !repository.mkdirs()) {
            log.warn("Cannot create upload temp dir {}", repository.getAbsolutePath());
        }
        DiskFileItemFactory factory = new DiskFileItemFactory();
        factory.setSizeThreshold(UPLOAD_SIZE_THRESHOLD);
        factory.setRepository(repository);
        ServletFileUpload fileUpload = new ServletFileUpload(factory);
        fileUpload.setHeaderEncoding("utf-8");
        fileUpload.setSizeMax(MAX_SIZE);

        List<FileItem> fileItemList;
        try {
            RequestContext requestContext = new ServletRequestContext(request);
            fileItemList = fileUpload.parseRequest(requestContext);
        } catch (Exception exception) {
            log.error("Failed to parse multipart request for {}", uri, exception);
            writeFailure(response, HttpStatus.ERROR, exception.getMessage());
            return;
        }

        if (fileItemList == null || fileItemList.isEmpty()) {
            writeFailure(response, HttpStatus.ERROR, "没有文件");
            return;
        }

        try {
            // Each part goes under its own field name. Previously every file field received ALL
            // items, including plain form fields wrapped as file resources.
            MultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
            for (FileItem fileItem : fileItemList) {
                if (fileItem.isFormField()) {
                    // explicit charset: the no-arg getString() uses the default encoding and
                    // garbles non-ASCII values
                    form.add(fileItem.getFieldName(), fileItem.getString("utf-8"));
                } else {
                    form.add(fileItem.getFieldName(), toResource(fileItem));
                }
            }

            // The incoming Content-Type carries the old multipart boundary, so it is replaced;
            // RestTemplate adds a fresh boundary when it encodes the form.
            HttpHeaders headers = copyHeaders(request, HttpHeaders.CONTENT_TYPE);
            headers.setContentType(MediaType.MULTIPART_FORM_DATA);
            HttpEntity<MultiValueMap<String, Object>> files = new HttpEntity<>(form, headers);

            ResponseEntity<String> responseEntity = restTemplate.exchange(buildUri(uri, request), method, files, String.class);
            response.setStatus(responseEntity.getStatusCodeValue());
            if (responseEntity.hasBody()) {
                // The body is already text; JSON-encoding it again produced a quoted string.
                response.getWriter().write(responseEntity.getBody());
            }
        } catch (HttpClientErrorException httpClientErrorException) {
            log.error("Backend rejected forwarded upload {} {}", method, uri, httpClientErrorException);
            writeFailure(response, httpClientErrorException.getRawStatusCode(), httpClientErrorException.getMessage());
        } catch (Exception exception) {
            log.error("Forwarding upload {} {} failed", method, uri, exception);
            writeFailure(response, HttpStatus.ERROR, "转发请求异常了" + exception.getMessage());
        } finally {
            // remove temp files of parts that were spilled to disk
            for (FileItem fileItem : fileItemList) {
                fileItem.delete();
            }
        }
    }

    /**
     * Forwards a non-multipart request with its body bytes unchanged.
     */
    private void doDispatch(HttpServletRequest request, HttpServletResponse response, HttpMethod method, String uri) throws IOException {
        byte[] requestBody = IOUtils.toByteArray(request.getInputStream());
        HttpHeaders headers = copyHeaders(request);
        if (requestBody.length > 0 && !headers.containsKey(HttpHeaders.CONTENT_TYPE)) {
            // bodies without a declared type have always been treated as JSON
            headers.setContentType(MediaType.APPLICATION_JSON);
        }
        // Forward the raw bytes rather than JSON.parse + re-serialise. Non-JSON bodies (e.g.
        // form-urlencoded) no longer fail to parse and reach the backend unchanged. A null body
        // keeps body-less requests (e.g. GET) body-less.
        HttpEntity<byte[]> httpEntity = new HttpEntity<>(requestBody.length == 0 ? null : requestBody, headers);
        PrintWriter writer = response.getWriter();
        try {
            ResponseEntity<OResult> exchange = restTemplate.exchange(buildUri(uri, request), method, httpEntity, OResult.class);
            response.setStatus(exchange.getStatusCodeValue());
            writer.write(JSON.toJSONString(exchange.getBody()));
        } catch (HttpClientErrorException httpClientErrorException) {
            log.error("Backend rejected forwarded request {} {}", method, uri, httpClientErrorException);
            writeFailure(response, httpClientErrorException.getRawStatusCode(), httpClientErrorException.getMessage());
        } catch (Exception exception) {
            log.error("Forwarding request {} {} failed", method, uri, exception);
            writeFailure(response, HttpStatus.ERROR, "转发请求异常了" + exception.getMessage());
        }
    }

    /**
     * Appends the incoming query string to {@code uri} exactly as the client sent it.
     *
     * <p>Decoding the query first (as before) turned {@code %26}, {@code %3D} and {@code +} into
     * separators or spaces, and let literal braces be treated as URI template variables. Passing
     * a {@link java.net.URI} to RestTemplate prevents a second round of encoding.
     *
     * @throws IllegalArgumentException if the resulting string is not a valid URI
     */
    private static java.net.URI buildUri(String uri, HttpServletRequest request) {
        String query = request.getQueryString();
        return java.net.URI.create(StringUtils.isEmpty(query) ? uri : uri + "?" + query);
    }

    /**
     * Copies all incoming headers (every value, not just the first) except
     * {@link #NON_FORWARDED_HEADERS} and {@code extraExcluded}; names are compared case-insensitively.
     */
    private static HttpHeaders copyHeaders(HttpServletRequest request, String... extraExcluded) {
        HttpHeaders headers = new HttpHeaders();
        Enumeration<String> headerNames = request.getHeaderNames();
        if (headerNames == null) {
            return headers;
        }
        while (headerNames.hasMoreElements()) {
            String name = headerNames.nextElement();
            if (containsIgnoreCase(NON_FORWARDED_HEADERS, name) || containsIgnoreCase(extraExcluded, name)) {
                continue;
            }
            Enumeration<String> values = request.getHeaders(name);
            while (values != null && values.hasMoreElements()) {
                headers.add(name, values.nextElement());
            }
        }
        return headers;
    }

    /** Returns whether {@code names} contains {@code name}, ignoring case. */
    private static boolean containsIgnoreCase(String[] names, String name) {
        for (String candidate : names) {
            if (candidate.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    /** Wraps an uploaded part's bytes as a resource that keeps the client-supplied file name. */
    private static ByteArrayResource toResource(FileItem fileItem) {
        final String filename = fileItem.getName();
        return new ByteArrayResource(fileItem.get()) {
            @Override
            public String getFilename() {
                return filename;
            }
        };
    }

    /** Sets {@code status} and writes an {@link OResult} failure body carrying the same status. */
    private static void writeFailure(HttpServletResponse response, int status, String message) throws IOException {
        response.setStatus(status);
        response.getWriter().write(JSON.toJSONString(OResult.fail(status, message)));
    }
}

项目中定义的返回实体,OResult.class

public class OResult<T> implements Serializable {
    private static final long serialVersionUID = 1L;
    private Integer status;
    private String msg;
    private T data;
    private String[] stack;

    /**
     * Creates an empty result whose status, message, data and stack are all {@code null}.
     * Prefer the static {@code success}/{@code fail} factories.
     */
    public OResult() {
        super();
    }

    /**
     * Creates a successful result with status 200, message {@code "SUCCESS"} and no data.
     *
     * @param <T> the data type of the result
     * @return a new success result
     */
    public static <T> OResult<T> success() {
        OResult<T> result = new OResult<>();
        result.setMsg("SUCCESS");
        result.setStatus(200);
        // Plain null: an (Object) cast is not assignable to a T parameter
        // (presumably setData(T) — the setter is not visible here).
        result.setData(null);
        return result;
    }

    /**
     * Creates a failure result with the generic message {@code "请求处理出错"}; status is left unset.
     *
     * @param <T> the data type of the result
     * @return a new failure result
     */
    public static <T> OResult<T> fail() {
        OResult<T> result = new OResult<>();
        result.setMsg("请求处理出错");
        return result;
    }

    /**
     * Creates a successful result with status 200, message {@code "SUCCESS"} and the given data.
     *
     * @param t   the payload to return
     * @param <T> the data type of the result
     * @return a new success result carrying {@code t}
     */
    public static <T> OResult<T> success(T t) {
        OResult<T> result = new OResult<>();
        result.setMsg("SUCCESS");
        result.setStatus(200);
        result.setData(t);
        return result;
    }
    /**
     * Creates a failure result with the given message; status is left unset.
     *
     * @param msg the failure message
     * @param <T> the data type of the result
     * @return a new failure result
     */
    public static <T> OResult<T> fail(String msg) {
        OResult<T> result = new OResult<>();
        result.setMsg(msg);
        return result;
    }

    /**
     * Creates a failure result with the given status code and message.
     *
     * @param status the status code stored in the result body
     * @param msg    the failure message
     * @param <T>    the data type of the result
     * @return a new failure result
     */
    public static <T> OResult<T> fail(int status, String msg) {
        OResult<T> result = new OResult<>();
        result.setStatus(status);
        result.setMsg(msg);
        return result;
    }

    /**
     * Creates a failure result describing an exception; status is left unset.
     *
     * <p>The message is {@code e.getMessage()}, or {@code e.toString()} when the exception has no
     * message. The stack frames of {@code e} are stored as well.
     *
     * @param e   the exception to describe (must not be null)
     * @param <T> the data type of the result
     * @return a new failure result
     */
    public static <T> OResult<T> fail(Throwable e) {
        OResult<T> result = new OResult<>();
        result.setMsg(StringUtils.defaultString(e.getMessage(), e.toString()));
        result.setStack(ExceptionUtils.getStackFrames(e));
        return result;
    }

附后端启动时将服务信息注册到zk的实现,ServerRegistry.class

package com.tiger.web.common;

import com.tiger.common.constant.GlobalConstants;
import lombok.extern.slf4j.Slf4j;
import org.I0Itec.zkclient.ZkClient;
import org.I0Itec.zkclient.exception.ZkNodeExistsException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.net.InetAddress;
import java.net.UnknownHostException;

/**
 * 往zk中注册服务
 *
 * @author tiger 2022/2/19
 */
@Component
@Slf4j
public class ServerRegistry implements ApplicationRunner {
    @Resource
    ZkClient zkClient;
    @Value("${zoneCode}")
    private String zoneCode;
    @Value("${server.port}")
    private String port;

    private static final String REGISTRY_SERVER_PATH = GlobalConstants.ZK_NODE_SEPARATOR + GlobalConstants.DEFAULT_PROJECT + "/server";
    private static final long REGISTRY_TIME_OUT_MS = 30000;

    @Override
    public void run(ApplicationArguments args) throws Exception {
        // TODO:临时解决window启动不注册,单侧跑不出来问题
        String osName = System.getProperty("os.name").toLowerCase();
        if (osName.startsWith("win")) {
            log.info("Win system skip registry.");
            return;
        }
        // 创建根节点
        String zkNode = getIp() + ":" + port;
        log.info("Start registry server {}.", zkNode);
        if (!zkClient.exists(REGISTRY_SERVER_PATH)) {
            zkClient.createPersistent(REGISTRY_SERVER_PATH, true);
        }
        if (!zkClient.exists(REGISTRY_SERVER_PATH + GlobalConstants.ZK_NODE_SEPARATOR + zoneCode)) {
            zkClient.createPersistent(REGISTRY_SERVER_PATH + GlobalConstants.ZK_NODE_SEPARATOR + zoneCode);
        }
        // 注册服务,创建临时节点
        long nextTime = System.currentTimeMillis() + REGISTRY_TIME_OUT_MS;
        boolean registry = false;
        while (nextTime - System.currentTimeMillis() > 0 && !registry) {
            try {
                zkClient.createEphemeral(REGISTRY_SERVER_PATH + GlobalConstants.ZK_NODE_SEPARATOR + zoneCode + GlobalConstants.ZK_NODE_SEPARATOR + zkNode);
                registry = true;
                log.info("Registry success.");
                break;
            } catch (ZkNodeExistsException zkNodeExistsException) {
                log.error("zkNode exist cause registry server failure.");
                Thread.sleep(3000);
            } catch (Exception exception) {
                exception.printStackTrace();
                log.error("registry server failure.", exception);
                Thread.sleep(3000);
            }
        }
        if (!registry) {
            // 若注册不成功,直接退出服务
            log.error("Fatal error: registry server {} failure and exit.", zkNode);
            System.exit(1);
        }
    }

    private String getIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            log.error("Fatal error: get ip error and exit.");
            System.exit(1);
            e.printStackTrace();
        }
        return null;
    }
}

主要依赖:

<!--zk client-->
<dependency>
      <groupId>com.101tec</groupId>
      <artifactId>zkclient</artifactId>
</dependency>

参考文献:

本文有很多地方可以用更好的方式实现,比如:转发服务可以利用zk的watch机制在本地缓存各区域的节点列表并监听其变化,而不必在每个请求中都查询zk。
待补充

相关文章

网友评论

      本文标题:【SpringBoot】2022-03-26【自定义请求转发、分

      本文链接:https://www.haomeiwen.com/subject/jlojjrtx.html