做后端的同学都知道,XSS过滤器、防SQL注入过滤器等是常用的。关于什么是XSS攻击,网上的说法很多,自己百度一下吧。我们只需要加入一个xss过滤器就可以了。但大部分的文章都是针对普通的get/post请求进行的参数xss过滤,如果是post+application/json等提交方式就显得无能为力了。很显然,对于普通的get/post,就按照之前的方式过滤参数就可以了(对于multipart等文件上传之类的参数做校验,这里暂时没有提供)。
为了方便起见,首先定义一些常量吧:
package cn.wjp.mydaily.common.filter;
/**
 * String constants for common HTTP content types and request methods.
 *
 * <p>This is a pure constants holder: it is final and cannot be instantiated.
 */
public final class HttpConst {

    /* Common Content-Type values. */
    public static final String FORM_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded";
    public static final String JSON_CONTENT_TYPE = "application/json";
    public static final String MULTIPART_CONTENT_TYPE = "multipart/form-data";

    /* Common request method names, spelled in lower case. */
    public static final String POST_METHOD = "post";
    public static final String GET_METHOD = "get";
    public static final String OPTIONS_METHOD = "options";

    /** Not instantiable: the class only exposes static constants. */
    private HttpConst() {
    }
}
接下来需要分别针对不同的表单提交类型写包装类,对于普通的get/post请求的包装类:
package cn.wjp.mydaily.common.filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
/**
* 适用于普通get post请求 不包含multipart/form-data application/json等请求
* 缓存请求参数 重写获取参数的方法
*/
/**
 * Request wrapper for ordinary GET/POST parameter requests (not
 * multipart/form-data, not application/json).
 *
 * <p>The constructor copies every request parameter into an internal map, and
 * all parameter getters are answered from that map. Replacing the map through
 * {@link #setParameterMap(Map)} therefore changes what later callers of
 * {@code getParameter*} see.
 */
public class HttpServletRequestNormalWrapper extends HttpServletRequestWrapper {

    /** All request parameters, keyed by parameter name. Never null. */
    private Map<String, String[]> parameterMap = new HashMap<>();

    /**
     * Wraps the request and caches all of its parameters.
     *
     * @param request the request whose parameters are copied
     */
    public HttpServletRequestNormalWrapper(HttpServletRequest request) {
        super(request);
        Enumeration<String> names = request.getParameterNames();
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            parameterMap.put(name, request.getParameterValues(name));
        }
    }

    /**
     * Returns the names of all cached parameters.
     *
     * @return an enumeration over a copy of the current parameter names
     */
    @Override
    public Enumeration<String> getParameterNames() {
        // Copy the key set so the enumeration is not affected by later map changes.
        Vector<String> names = new Vector<String>(parameterMap.keySet());
        return names.elements();
    }

    /**
     * Returns the first value of the named parameter.
     *
     * @param name the parameter name
     * @return the first value, or {@code null} when the parameter is absent or has no values
     */
    @Override
    public String getParameter(String name) {
        String[] values = parameterMap.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    /**
     * Returns every value of the named parameter (e.g. all checked checkboxes).
     *
     * @param name the parameter name
     * @return the stored value array, or {@code null} when the parameter is absent
     */
    @Override
    public String[] getParameterValues(String name) {
        return parameterMap.get(name);
    }

    /**
     * Returns the cached parameter map itself (not a copy).
     *
     * @return the map backing this wrapper's parameter getters
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        return parameterMap;
    }

    /**
     * Replaces the cached parameters.
     *
     * @param parameterMap the new parameters; {@code null} is treated as "no parameters"
     *                     so that the getters keep working instead of throwing
     */
    public void setParameterMap(Map<String, String[]> parameterMap) {
        if (parameterMap == null) {
            this.parameterMap = new HashMap<>();
        } else {
            this.parameterMap = parameterMap;
        }
    }
}
对于post+application/json的请求,对应的包装类:
package cn.wjp.mydaily.common.filter;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
/**
* 将body中的数据缓存起来 重写getInputStream getReader 等方法 适用于application/json的post请求
*/
/**
 * Request wrapper that reads the whole request body once, keeps it as a string,
 * and serves {@link #getInputStream()} / {@link #getReader()} from that cached
 * copy. Intended for application/json POST requests, whose body would otherwise
 * be consumable only once.
 *
 * <p>The body is decoded and re-encoded with the same charset (UTF-8), so the
 * bytes handed downstream match what was read regardless of the platform
 * default charset.
 */
public class HttpServletRequestBodyReaderWrapper extends HttpServletRequestWrapper {

    /** Charset used both to decode the incoming body and to re-encode the cached one. */
    private static final String BODY_CHARSET = "UTF-8";

    /** Cached request body. Never null. */
    private String body = "{}";

    /**
     * Wraps the request and drains its body into memory.
     *
     * @param request the request whose body is cached
     * @throws IOException if reading the body fails
     */
    public HttpServletRequestBodyReaderWrapper(HttpServletRequest request) throws IOException {
        super(request);
        StringBuilder buffer = new StringBuilder();
        InputStream inputStream = request.getInputStream();
        if (inputStream != null) {
            // try-with-resources closes the reader (and the underlying stream) on every path.
            try (Reader reader = new BufferedReader(new InputStreamReader(inputStream, BODY_CHARSET))) {
                char[] chunk = new char[1024];
                int charsRead;
                while ((charsRead = reader.read(chunk)) != -1) {
                    buffer.append(chunk, 0, charsRead);
                }
            }
        }
        body = buffer.toString();
    }

    /**
     * Returns a fresh stream over the cached body; may be called repeatedly.
     *
     * @return a stream positioned at the start of the cached body
     * @throws IOException if the body cannot be encoded
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body.getBytes(BODY_CHARSET));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // Finished once every cached byte has been handed out.
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported by this in-memory stream; the listener is ignored.
            }

            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk read straight from the buffer instead of byte-by-byte.
                return source.read(b, off, len);
            }
        };
    }

    /**
     * Returns a fresh reader over the cached body, decoded with the same charset
     * that {@link #getInputStream()} encodes with.
     *
     * @return a reader positioned at the start of the cached body
     * @throws IOException if the body cannot be encoded
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), BODY_CHARSET));
    }

    /**
     * @return the cached request body
     */
    public String getBody() {
        return this.body;
    }

    /**
     * Replaces the cached body served to downstream readers.
     *
     * @param body the new body; {@code null} is stored as an empty string so that
     *             {@link #getInputStream()} keeps working
     */
    public void setBody(String body) {
        this.body = (body == null) ? "" : body;
    }
}
接下来就是校验啦:
package cn.wjp.mydaily.common.filter;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
import javax.servlet.*;
import javax.servlet.FilterConfig;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
 * Servlet filter that sanitizes request data against XSS before it reaches the
 * rest of the chain.
 *
 * <ul>
 *   <li>GET requests and form-encoded POST requests: every parameter name and
 *       value is cleaned and served through a {@link HttpServletRequestNormalWrapper}.</li>
 *   <li>application/json POST requests: every string inside the JSON body is
 *       cleaned and the body is served through a
 *       {@link HttpServletRequestBodyReaderWrapper}.</li>
 *   <li>multipart/form-data POST requests and everything else pass through untouched.</li>
 * </ul>
 */
public class XssFilter implements Filter {

    /** Matches an {@code eval(...)} call. Must be applied before parentheses are escaped. */
    private static final Pattern EVAL_CALL = Pattern.compile("eval\\((.*)\\)");

    /** Matches a quoted {@code javascript:} URL. Must be applied before quotes are escaped. */
    private static final Pattern QUOTED_JAVASCRIPT_URL =
            Pattern.compile("[\\\"\\\'][\\s]*javascript:(.*)[\\\"\\\']");

    /**
     * Routes the request to the matching sanitizing strategy based on its method
     * and Content-Type, then continues the chain.
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;
        String method = request.getMethod().trim();
        // A request may carry no Content-Type header at all; treat that as "unknown".
        String contentType = normalizeContentType(request.getContentType());

        // 1. GET: only URL parameters to clean.
        if (method.equalsIgnoreCase(HttpConst.GET_METHOD)) {
            chain.doFilter(wrapWithCleanParameters(request), response);
            return;
        }

        // 2. POST: form-encoded and JSON bodies are cleaned, multipart is passed through.
        if (method.equalsIgnoreCase(HttpConst.POST_METHOD)) {
            if (contentType.contains(HttpConst.MULTIPART_CONTENT_TYPE)) {
                chain.doFilter(request, response);
                return;
            }
            if (contentType.contains(HttpConst.FORM_URLENCODED_CONTENT_TYPE)) {
                chain.doFilter(wrapWithCleanParameters(request), response);
                return;
            }
            if (contentType.contains(HttpConst.JSON_CONTENT_TYPE)) {
                HttpServletRequestBodyReaderWrapper wrapper = new HttpServletRequestBodyReaderWrapper(request);
                wrapper.setBody(cleanXSSForPostJsonRequest(wrapper.getBody()));
                chain.doFilter(wrapper, response);
                return;
            }
        }

        // Anything else (other methods, other or missing content types) is not modified.
        chain.doFilter(request, response);
    }

    /** Lower-cases and trims a Content-Type header value; a missing header becomes "". */
    private static String normalizeContentType(String contentType) {
        return contentType == null ? "" : contentType.trim().toLowerCase(Locale.ROOT);
    }

    /** Wraps the request so that its parameters are served in cleaned form. */
    private HttpServletRequest wrapWithCleanParameters(HttpServletRequest request) {
        HttpServletRequestNormalWrapper wrapper = new HttpServletRequestNormalWrapper(request);
        wrapper.setParameterMap(cleanXSSForNormalRequest(wrapper.getParameterMap()));
        return wrapper;
    }

    /**
     * Sanitizes a single string.
     *
     * <p>Order matters: the {@code eval(...)} and quoted {@code javascript:} patterns
     * are removed first, while the raw parentheses and quotes they rely on are still
     * present; only afterwards are the special characters replaced by HTML entities.
     *
     * @param value the raw text
     * @return the sanitized text, or {@code null} when {@code value} is null or blank
     */
    public String cleanXSS(String value) {
        if (value == null || value.trim().isEmpty()) {
            return null;
        }
        String cleaned = value;
        // Pattern-based removals (need the unescaped characters).
        cleaned = EVAL_CALL.matcher(cleaned).replaceAll("");
        cleaned = QUOTED_JAVASCRIPT_URL.matcher(cleaned).replaceAll("\"\"");
        cleaned = cleaned.replace("`", "");
        cleaned = cleaned.replace("script", "");
        // Entity escaping of the remaining special characters.
        cleaned = cleaned.replace("<", "&lt;").replace(">", "&gt;");
        cleaned = cleaned.replace("(", "&#40;").replace(")", "&#41;");
        cleaned = cleaned.replace("'", "&#39;");
        cleaned = cleaned.replace("\"", "&#34;");
        return cleaned;
    }

    /**
     * Sanitizes the parameters of an ordinary GET/POST request.
     *
     * @param parameterMap parameter names mapped to their values; may be null
     * @return a new map with every name and value passed through {@link #cleanXSS(String)};
     *         a null or empty value array is stored as {@code null}
     */
    public Map<String, String[]> cleanXSSForNormalRequest(Map<String, String[]> parameterMap) {
        Map<String, String[]> cleanMap = new HashMap<>();
        if (parameterMap == null || parameterMap.isEmpty()) {
            return cleanMap;
        }
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            String[] rawValues = entry.getValue();
            String[] cleanValues = null;
            if (rawValues != null && rawValues.length > 0) {
                cleanValues = new String[rawValues.length];
                for (int i = 0; i < rawValues.length; i++) {
                    cleanValues[i] = cleanXSS(rawValues[i]);
                }
            }
            cleanMap.put(cleanXSS(entry.getKey()), cleanValues);
        }

        // Debug output of the cleaned parameters.
        // NOTE(review): this prints every request parameter to stdout — confirm that is acceptable.
        StringBuffer printStr = new StringBuffer();
        for (Map.Entry<String, String[]> cleaned : cleanMap.entrySet()) {
            // Arrays.toString tolerates a null array, unlike Arrays.asList.
            printStr.append(cleaned.getKey()).append("=").append(Arrays.toString(cleaned.getValue())).append("&");
        }
        System.out.println("XssFilter:发送的请求参数:" + JSON.toJSONString(printStr));
        return cleanMap;
    }

    /**
     * Sanitizes the body of an application/json POST request.
     *
     * <p>Every string (object keys and string values, at any nesting depth) is
     * passed through {@link #cleanXSS(String)}; numbers, booleans and nulls keep
     * their type, and nested objects and arrays keep their structure.
     *
     * @param body the raw JSON text; may be null
     * @return the sanitized JSON text, or {@code "{}"} when the body is null, blank,
     *         or not a JSON object/array
     */
    public String cleanXSSForPostJsonRequest(String body) {
        String emptyBody = "{}";
        if (body == null) {
            return emptyBody;
        }
        String trimmed = body.trim();
        if (!trimmed.startsWith("{") && !trimmed.startsWith("[")) {
            return emptyBody;
        }
        Object parsed = JSON.parse(trimmed);
        if (parsed == null) {
            return emptyBody;
        }
        String cleanBody = JSON.toJSONString(cleanJsonValue(parsed));
        System.out.println("XssFilter:发送的请求参数:" + cleanBody);
        return cleanBody;
    }

    /** Recursively cleans the strings inside a parsed JSON value, leaving other types untouched. */
    private Object cleanJsonValue(Object value) {
        if (value instanceof String) {
            return cleanXSS((String) value);
        }
        if (value instanceof Map) {
            Map<String, Object> cleaned = new HashMap<>();
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
                cleaned.put(cleanXSS(String.valueOf(entry.getKey())), cleanJsonValue(entry.getValue()));
            }
            return cleaned;
        }
        if (value instanceof List) {
            List<?> items = (List<?>) value;
            List<Object> cleaned = new ArrayList<>(items.size());
            for (Object item : items) {
                cleaned.add(cleanJsonValue(item));
            }
            return cleaned;
        }
        // Numbers, booleans and null carry no markup and are kept as they are.
        return value;
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
    }

    /** No configuration is read. */
    @Override
    public void init(FilterConfig arg0) {
    }
}