美文网首页Springboot
dubbo rpc调用参数校验

dubbo rpc调用参数校验

作者: 上重楼 | 来源:发表于2018-08-10 11:20 被阅读207次

    使用spring的时候http调用参数校验还是很方便的,只是我们rpc用的比较多,然后就有了这个了。

    其实实现很简单,实现一个dubbo的Filter 然后在这里根据反射获取参数的注解使用javax.validation的注解做各种操作就行了。要使用的项目需要加入自定义的Filter
    类似:


    image.png

    加上paramValidationFilter 多个Filter用英文逗号分割。
    加这个的意思是,这个项目的全局rpc调用Filter加入自定义的参数校验。 就算你的项目不是使用xml的配置方法,也可有其他类似的行为。需要注意多个Filter的优先级

    我们的rpc都是返回一个固定的包装对象 ResultData<T> 所以会返回一个有错误状态码和错误信息的包装ResultData给调用方
    如果返回类型不为ResultData会抛出我们自定义的异常

    上实现代码:

    package com.jx.common.dubbo.filter;
    
    import com.alibaba.dubbo.common.extension.Activate;
    import com.alibaba.dubbo.rpc.*;
    import com.google.common.collect.Maps;
    import com.jx.common.exception.BusinessException;
    import com.jx.common.model.ResultData;
    import lombok.Data;
    import lombok.extern.slf4j.Slf4j;
    
    import javax.validation.ConstraintViolation;
    import javax.validation.Valid;
    import javax.validation.Validation;
    import javax.validation.Validator;
    import javax.validation.constraints.*;
    import javax.validation.groups.Default;
    import java.lang.annotation.Annotation;
    import java.lang.reflect.Method;
    import java.lang.reflect.Parameter;
    import java.math.BigDecimal;
    import java.math.BigInteger;
    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.stream.Collectors;
    
    /**
     * @author 周广
     **/
    @Slf4j
    @Activate(order = -100)
    public class ParamValidationFilter implements Filter {
        private static Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
    
    
        /**
         * 对象实体校验
         *
         * @param obj 待校验对象
         * @param <T> 待校验对象的泛型
         * @return 校验结果
         */
        private static <T> ValidationResult validateEntity(T obj) {
            Set<ConstraintViolation<T>> set = validator.validate(obj, Default.class);
            return getValidationResult(set);
        }
    
        /**
         * 将校验结果转换返回对象
         *
         * @param set 错误信息set
         * @param <T> 校验对象的泛型
         * @return 校验结果
         */
        private static <T> ValidationResult getValidationResult(Set<ConstraintViolation<T>> set) {
            ValidationResult result = new ValidationResult();
            if (set != null && !set.isEmpty()) {
                result.setHasErrors(true);
                Map<String, String> errorMsg = Maps.newHashMap();
                for (ConstraintViolation<T> violation : set) {
                    errorMsg.put(violation.getPropertyPath().toString(), violation.getMessage());
                }
                result.setErrorMsg(errorMsg);
            }
            return result;
        }
    
        /**
         * 方法级别的参数验证 手撕o(╥﹏╥)o
         *
         * @param annotation 待验证的注解
         * @param param      待校验参数
         * @param paramName  参数名称
         * @return 校验结果
         */
        static ValidationResult validateMethod(Annotation annotation, Object param, String paramName) {
            ValidationResult result = new ValidationResult();
            Map<String, String> errorMsg = Maps.newHashMap();
            result.setErrorMsg(errorMsg);
    
            if (annotation instanceof DecimalMax) {
    
                if (param instanceof BigDecimal) {
                    BigDecimal value = new BigDecimal(((DecimalMax) annotation).value());
                    if (((BigDecimal) param).compareTo(value) > 0) {
                        result.setHasErrors(true);
                        if ("{javax.validation.constraints.DecimalMax.message}".equals(((DecimalMax) annotation).message())) {
                            errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + value);
                        } else {
                            errorMsg.put(paramName, ((DecimalMax) annotation).message());
                        }
                    }
                } else if (param instanceof BigInteger) {
                    BigInteger value = new BigInteger(((DecimalMax) annotation).value());
                    if (((BigInteger) param).compareTo(value) > 0) {
                        result.setHasErrors(true);
                        if ("{javax.validation.constraints.DecimalMax.message}".equals(((DecimalMax) annotation).message())) {
                            errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + value);
                        } else {
                            errorMsg.put(paramName, ((DecimalMax) annotation).message());
                        }
                    }
                } else {
                    result.setHasErrors(true);
                    errorMsg.put(paramName, "类型:" + param.getClass().getName() + "与注解:" + annotation.getClass().getName() + "不匹配");
                }
    
            } else if (annotation instanceof DecimalMin) {
                if (param instanceof BigDecimal) {
                    BigDecimal value = new BigDecimal(((DecimalMin) annotation).value());
                    if (((BigDecimal) param).compareTo(value) < 0) {
                        result.setHasErrors(true);
                        if ("{javax.validation.constraints.DecimalMin.message}".equals(((DecimalMin) annotation).message())) {
                            errorMsg.put(paramName, paramName + param.getClass().getName() + " 值为:" + param + "小于" + value);
                        } else {
                            errorMsg.put(paramName, ((DecimalMin) annotation).message());
                        }
                    }
                } else if (param instanceof BigInteger) {
                    BigInteger value = new BigInteger(((DecimalMin) annotation).value());
                    if (((BigInteger) param).compareTo(value) < 0) {
                        result.setHasErrors(true);
                        if ("{javax.validation.constraints.DecimalMin.message}".equals(((DecimalMin) annotation).message())) {
                            errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + value);
                        } else {
                            errorMsg.put(paramName, ((DecimalMin) annotation).message());
                        }
                    }
                } else {
                    result.setHasErrors(true);
                    errorMsg.put(paramName, "类型:" + param.getClass().getName() + "与注解:" + annotation.getClass().getName() + "不匹配");
                }
            } else if (annotation instanceof Max) {
                long value = ((Max) annotation).value();
                if (Long.valueOf(param.toString()) > value) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.Max.message}".equals(((Max) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + value);
                    } else {
                        errorMsg.put(paramName, ((Max) annotation).message());
                    }
                }
            } else if (annotation instanceof Min) {
                long value = ((Min) annotation).value();
                if (Long.valueOf(param.toString()) < value) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.Min.message}".equals(((Min) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + value);
                    } else {
                        errorMsg.put(paramName, ((Min) annotation).message());
                    }
                }
    
            } else if (annotation instanceof NotNull) {
                if (param == null) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.NotNull.message}".equals(((NotNull) annotation).message())) {
                        errorMsg.put(paramName, "值为null");
                    } else {
                        errorMsg.put(paramName, ((NotNull) annotation).message());
                    }
                }
            } else if (annotation instanceof Size) {
                int value = Integer.valueOf(param.toString());
                if (value > ((Size) annotation).max()) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.Size.message}".equals(((Size) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "大于" + ((Size) annotation).max());
                    } else {
                        errorMsg.put(paramName, ((Size) annotation).message());
                    }
                } else if (value < ((Size) annotation).min()) {
                    result.setHasErrors(true);
                    if ("{javax.validation.constraints.Size.message}".equals(((Size) annotation).message())) {
                        errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + ((Size) annotation).min());
                    } else {
                        errorMsg.put(paramName, ((Size) annotation).message());
                    }
                    errorMsg.put(paramName, param.getClass().getName() + " 值为:" + param + "小于" + ((Size) annotation).min());
                }
            }
            return result;
        }
    
        @Override
        public Result invoke(Invoker<?> invoker, Invocation invocation) throws RpcException {
            Method method = null;
    
            for (Method m : invoker.getInterface().getDeclaredMethods()) {
                if (m.getName().equals(invocation.getMethodName()) && invocation.getArguments().length == m.getParameterCount()) {
                    Class[] invokerMethodParamClassList = invocation.getParameterTypes();
                    Class[] matchMethodParamClassList = m.getParameterTypes();
                    if (verifyClassMatch(invokerMethodParamClassList, matchMethodParamClassList)) {
                        method = m;
                        break;
                    }
                }
            }
    
            //如果找不到对应的方法就跳过参数校验
            if (method == null) {
                return invoker.invoke(invocation);
            }
    
    
            //一个参数可以有多个注解
            Annotation[][] paramAnnotation = method.getParameterAnnotations();
            //参数的class
            Class<?>[] paramClass = invocation.getParameterTypes();
            Object[] paramList = invocation.getArguments();
            //获取参数名称
            List<String> paramNameList = Arrays.stream(method.getParameters()).map(Parameter::getName).collect(Collectors.toList());
    
            for (int i = 0; i < paramList.length; i++) {
                Object param = paramList[i];
    
    
                Annotation[] annotations = paramAnnotation[i];
                if (annotations.length == 0) {
                    continue;
                }
    
    
                //循环注解处理,快速失败
                for (Annotation annotation : annotations) {
                    if (isJavaxValidationAnnotation(annotation)) {
                        ValidationResult result;
                        try {
                            result = validateMethod(annotation, param, paramNameList.get(i));
                        } catch (Exception e) {
                            log.error("参数校验异常", e);
                            return invoker.invoke(invocation);
                        }
                        if (result.isHasErrors()) {
                            if (ResultData.class.equals(method.getReturnType())) {
                                return generateFailResult(result.getErrorMsg().toString());
                            } else {
                                throw new BusinessException(result.getErrorMsg().toString());
                            }
                        }
                    } else if (annotation instanceof Valid) {
                        if (param == null) {
                            String errMsg = String.format("待校验对象:%s 不可为null", paramClass[i].getName());
                            if (ResultData.class.equals(method.getReturnType())) {
                                return generateFailResult(errMsg);
                            } else {
                                throw new BusinessException(errMsg);
                            }
                        }
                        ValidationResult result = validateEntity(param);
                        if (result.isHasErrors()) {
                            if (ResultData.class.equals(method.getReturnType())) {
                                return generateFailResult(result.getErrorMsg().toString());
                            } else {
                                throw new BusinessException(result.getErrorMsg().toString());
                            }
                        }
                    }
                }
            }
    
            return invoker.invoke(invocation);
        }
    
        private boolean verifyClassMatch(Class[] invokerMethodParamClassList, Class[] matchMethodParamClassList) {
            for (int i = 0; i < invokerMethodParamClassList.length; i++) {
                if (!invokerMethodParamClassList[i].equals(matchMethodParamClassList[i])) {
                    return false;
                }
            }
            return true;
        }
    
        /**
         * 构建失败返回对象
         *
         * @param errorMsg 错误信息
         * @return dubbo 返回对象
         */
        private Result generateFailResult(String errorMsg) {
            //请求参数非法
            return new RpcResult(new ResultData<>("0001", errorMsg));
        }
    
        /**
         * 判断是否为javax.validation.constraints的注解
         *
         * @param annotation 目标注解
         */
        private boolean isJavaxValidationAnnotation(Annotation annotation) {
            if (annotation instanceof AssertFalse) {
                return true;
            } else if (annotation instanceof AssertTrue) {
                return true;
            } else if (annotation instanceof DecimalMax) {
                return true;
            } else if (annotation instanceof DecimalMin) {
                return true;
            } else if (annotation instanceof Digits) {
                return true;
            } else if (annotation instanceof Future) {
                return true;
            } else if (annotation instanceof Max) {
                return true;
            } else if (annotation instanceof Min) {
                return true;
            } else if (annotation instanceof NotNull) {
                return true;
            } else if (annotation instanceof Null) {
                return true;
            } else if (annotation instanceof Past) {
                return true;
            } else if (annotation instanceof Pattern) {
                return true;
            } else if (annotation instanceof Size) {
                return true;
            }
            return false;
        }
    
        @Data
        static class ValidationResult {
    
            /**
             * 校验结果是否有错
             */
            private boolean hasErrors;
    
            /**
             * 校验错误信息
             */
            private Map<String, String> errorMsg;
        }
    }
    
    

    使用文档直接就从公司内部我写的wiki文档复制过来的, 送佛送到西
    使用方法很简单,如果需要使用参数校验需要在接口的方法上面加上一些注解,注意是接口! 实现类加了没用


    image.png

    这里我们要区分2类参数

    第一类:自定义对象 必须使用@Valid注解,标记这个对象需要进行校验
    第二类:基本类型、包装类型、Decimal类型、String 等等一般都是java自带的类型

    第二类可以使用:
    第一个:javax.validation.constraints.DecimalMax
    用于比较BigDecimal或者BigInteger
    例子:

    public void test1(@DecimalMax(value = "10", message = "最大为十") BigDecimal value)
    

    意思是这个值最大为10 message是超出范围的时候的错误信息

    第二个:

    javax.validation.constraints.DecimalMin

    和@DecimalMax 一样的使用方法,只是这个是最低 刚好相反

    第三第四个:

    javax.validation.constraints.Max

    javax.validation.constraints.Min

    例子:

    public void test3(@Max(value = 10) int value) {
    public void test4(@Min(value = 10,message = "最小为10") Integer value) 
    

    这个可以用于比较多种整形数值类型

    第五个:

    javax.validation.constraints.NotNull

    例子:

    public void test5(@NotNull int value)
    

    意思是校验的参数不可为null

    第六个:

    javax.validation.constraints.Size

    例子:

    public void test6(@Size(min = 10, max = 100,message = "最小10 最大100 不在10到100就报错") int value) 
    

    设置范围的。

    如上所述,虽然使用的是javax.validation.constraints下的注解,但是基本只支持注解的 value 和message 2个通用参数 别的都不支持! 因为我是手撕的校验 (validateMethod方法 逻辑可以自己调整)

    接下来是第一类:自定义对象的校验方法


    image.png

    对数组、List使用的@NotEmpty 是org.hibernate.validator.constraints.NotEmpty 用于校验数组、List不为空

    而且对象字段testClass2头上有2个注解

    @NotNull
    @Valid
    NotNull用于标记这个对象自身不可为null, Valid用于标记这个对象内部还需要校验

    所以testClass2List的2个注解也很容易理解了,就是标记List不可为null,且还要检查List里面每一个对象的字段

    image.png

    如图所示,如果是基本类型、包装类型、String、Decimal 等内建类型 可以使用如下注解

    而且是完整的功能,除了上面说的value 和message2个通用参数,还可以使用其他所有javax.validation.constraints支持的功能。因为自定义对象的校验是使用javax.validation提供的校验

    相关文章

      网友评论

        本文标题:dubbo rpc调用参数校验

        本文链接:https://www.haomeiwen.com/subject/kickbftx.html