继承 Filter 基类 OncePerRequestFilter,保证每个请求(包括转发)只执行一次该过滤器:
public class MyAuthenticationProcessingFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
filterChain.doFilter(requestWrapper, response);
}
出现的问题:在 filter 中读取(消费)了 Request 的 InputStream 之后,
后续的过滤器和处理器就无法再次读取请求体(请求的输入流只能被读取一次)。
解决方法:定义一个继承 HttpServletRequestWrapper 的类,先把输入流中的字节数据读取出来并缓存,以供后续使用;
再重写 getInputStream()(以及 getReader())方法,每次调用时都把缓存的字节数组重新封装成一个新的 ServletInputStream
返回即可。注意字符编码要与请求声明的编码保持一致。
ServletRequestWrapper.java
/**
 * Request wrapper that reads the request body once, caches the raw bytes in memory, and replays
 * them on every {@link #getInputStream()} / {@link #getReader()} call. The body can then be read
 * in a filter and passed on to the next filter intact.
 */
public class ServletRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used when the request declares none or declares an unusable one. */
    private static final Charset DEFAULT_CHARSET = Charset.forName("UTF-8");

    /** The request body exactly as received (no decoding/re-encoding round trip). */
    private final byte[] body;

    /** The body decoded as text with the request's character encoding (UTF-8 if none is declared). */
    private final String requestParam;

    /**
     * Constructs a request object wrapping the given request and caches its body.
     *
     * <p>The wrapped request's input stream is read to the end here, so from then on the body must be
     * read through this wrapper.
     *
     * @param request the request to wrap
     * @throws IllegalArgumentException if the request is null (raised by the superclass)
     * @throws java.io.UncheckedIOException if reading the request body fails
     */
    public ServletRequestWrapper(HttpServletRequest request) {
        super(request);
        try {
            body = readAllBytes(request.getInputStream());
        } catch (IOException e) {
            // Failing loudly is better than silently continuing with an empty/partial body.
            throw new java.io.UncheckedIOException("Failed to read request body", e);
        }
        requestParam = new String(body, resolveCharset(request.getCharacterEncoding()));
    }

    /**
     * Returns a reader over the cached body, decoded with the request's character encoding.
     *
     * @return a new reader positioned at the start of the body
     * @throws IOException never in practice; declared by the overridden method
     */
    @Override
    public BufferedReader getReader() throws IOException {
        // Must read the cached copy: the wrapped request's own stream was already drained in the constructor.
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset(getCharacterEncoding())));
    }

    /**
     * Returns a new stream over the cached body. Each call starts again from the first byte.
     *
     * @return a fresh input stream over the body
     * @throws IOException never in practice; declared by the overridden method
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new CachedBodyServletInputStream(body);
    }

    /**
     * Returns the cached body as text.
     *
     * @return the body decoded with the request's character encoding (UTF-8 if none is declared)
     */
    public String getRequestParam() {
        return requestParam;
    }

    /** Maps a charset name to a {@link Charset}, falling back to UTF-8 if it is null or unusable. */
    private static Charset resolveCharset(String encoding) {
        if (encoding == null) {
            return DEFAULT_CHARSET;
        }
        try {
            return Charset.forName(encoding);
        } catch (IllegalArgumentException e) {
            // Covers IllegalCharsetNameException and UnsupportedCharsetException.
            return DEFAULT_CHARSET;
        }
    }

    /** Reads the stream to end-of-stream and returns every byte read. */
    private static byte[] readAllBytes(ServletInputStream in) throws IOException {
        java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int n;
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
        }
        return out.toByteArray();
    }

    /** In-memory {@link ServletInputStream} over a byte array; always ready, finished once drained. */
    private static final class CachedBodyServletInputStream extends ServletInputStream {

        private final ByteArrayInputStream delegate;

        CachedBodyServletInputStream(byte[] data) {
            this.delegate = new ByteArrayInputStream(data);
        }

        @Override
        public boolean isFinished() {
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // All data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener listener) {
            // A silent no-op would leave a non-blocking reader waiting for callbacks that never come.
            throw new UnsupportedOperationException("Non-blocking reads are not supported for a cached request body");
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }
    }
}
HttpUtil.java
public class HttpUtil {
public static String getBodyString(ServletRequest request) {
BufferedReader bufferedReader = null;
InputStream inputStream = null;
StringBuilder sb = new StringBuilder("");
try {
inputStream = request.getInputStream();
bufferedReader = new BufferedReader(new InputStreamReader(inputStream, Charset.forName("utf-8")));
String line = "";
while ((line = bufferedReader.readLine()) != null) {
sb.append(line);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
if (bufferedReader != null) {
try {
bufferedReader.close();
} catch (IOException e) {
e.printStackTrace();
}
}
if (inputStream != null) {
try {
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return sb.toString();
}
}
网友评论