美文网首页
spring 自定义注解、扫描器实现

spring 自定义注解、扫描器实现

作者: keetone | 来源:发表于2023-03-10 23:07 被阅读0次

    像jpa,mybatis,feign等框架在使用时都一个共同的特点,那就是只需要写接口就行,并没有具体的实现类。

    jpa 按照一种约定的格式去写方法名并不需要写实现类就能实现功能

    mybatis 在xml文件里面写上方法名对应的sql或者用注解写sql,也没有写mapper的实现类就能实现数据库查询功能

    feign 只需要写接口然后加上对应的注解就能实现http请求功能

    他们都是通过自定义代理实现的,即按照自己的方式去编写 FactoryBean 里面getObject方法的逻辑,生产代理对象,从而实现个性化功能。

    实现自定义注解的步骤

    要实现自定义注解,让它标记的类自动生成bean并且可以像普通的component 一样注入到其他bean中,那么要经历几个步骤呢,就像把大象装进冰箱一样,1开门,2装大象,3关门

    1. 编写自定义注解
    2. 编写接口
    3. 在接口上添加自定义注解
    4. 在程序启动的过程中找到有自定义注解的接口
    5. 为这些接口生成代理对象
    6. 将代理对象注册到spring的bean容器中

    可以将1,2,3一起看作是大象,4和5看作是打开冰箱门,6是把大象装进冰箱
    而这里关键的两步是 如何找到有自定义注解的接口和生成代理对象的bean

    这里提供两种实现自定义注解的方式,也是我在实际工作中为了实现功能而一步一步摸索出来的(呸......网上搜的)
    一种是通过实现 BeanDefinitionRegistryPostProcessor 接口
    另一种则是通过实现扫描器ClassPathBeanDefinitionScanner接口

    下面我将通过一个在工作中遇到的实际问题来展开

    场景

    业务系统和数据中台进行对接,数据中台作为所有数据的数据源头,向各个业务系统提供数据查询服务,因此业务系统也从连接mysql查数据变成了连接数据中台查数据,但是这个数据中台又没有实现jdbc协议,因此就无法通过像mysql那样使用orm框架。数据中台只提供了相关的http接口,可以支持传入sql语句,然后返回对应的数据。

    初步需求

    上面为了还原实际工作场景,描述就比较啰嗦,稍微提炼一下,形成需求:

    程序生成sql,然后将sql作为参数发送http请求。
    本质上就是一个http请求

    那在实际的开发过程中,第一次对接的时候,都是直接在业务代码里面用字符串拼接sql的(那sql是相当动态),然后拼好了sql后也不打印一下,最后发送http请求的时候也不打印(可真是够了......),每次一排查问题的时候那sql简直没法查,相当的人类不友好,无法阅读。

    思考一下

    既然是要写sql,那能不能像mybatis那样,把sql集中写在一个文件里面,它可是专业干这事儿的,这样sql的可读性不就大大的提高了么。
    那要怎么做呢?
    既然提到了mybatis,那我们就照虎画猫,照葫芦画瓢。我们也照着mybatis搞一套表面上看上去类似的,实际上又不一样的。sql写在xml里面,但是最后是走的http接口获取数据
    这就引发了更深层次的需求

    额外的需求

    定义一个接口,写一个方法,然后将sql写到xml文件里面,最终实现只需要调用接口的方法,就可以自动获取对应的sql并通过发送http请求去获取数据。
    为了实现这个额外的需求,我们需要干两件事儿。

    1. 需要有自定义的代理(这个就需要用到文章标题所提到的知识点了) ,这个是用来干sql解析,发送http请求获取数据逻辑的(本文这里就不写发送http请求部分代码了)
    2. 需要一个sql解析器(这个mybatis有现成的,直接拿过来用就好了(典型的拿来主义),还实现个锤锤哦)

    自定义注解类

    import java.lang.annotation.*;
    
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.TYPE})
    public @interface Keeton {
    }
    
    

    自定义FactoryBean创建代理类

    我这里要实现的功能是:通过接口的方法名获取xml里面对应的sql,然后将sql语句作为参数通过http请求发送到数据中台的接口

    import org.apache.ibatis.builder.xml.XMLMapperBuilder;
    import org.apache.ibatis.io.Resources;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.ParameterMapping;
    import org.apache.ibatis.mapping.ParameterMode;
    import org.apache.ibatis.reflection.MetaObject;
    import org.apache.ibatis.session.Configuration;
    import org.apache.ibatis.type.TypeHandlerRegistry;
    import org.jetbrains.annotations.NotNull;
    import org.springframework.beans.factory.FactoryBean;
    
    import java.io.IOException;
    import java.io.InputStream;
    import java.lang.reflect.*;
    import java.text.SimpleDateFormat;
    import java.util.*;
    
    public class MyFactoryBean<T> implements FactoryBean<T> {
        private final ThreadLocal<SimpleDateFormat> dateTimeFormatter = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
        private final Class<T>                      interfaceType;
    
        public MyFactoryBean(Class<T> interfaceType) {
            this.interfaceType = interfaceType;
        }
    
        @Override
        public T getObject() throws Exception {
            return (T) Proxy.newProxyInstance(
                    interfaceType.getClassLoader(),
                    new Class[]{interfaceType},
                    createProxy(interfaceType));
        }
    
        @Override
        public Class<?> getObjectType() {
            return interfaceType;
        }
    
    
        private InvocationHandler createProxy(Class<T> interfaceType) {
            return (proxy, method, args) -> {
                if (Object.class.equals(method.getDeclaringClass())) {
                    return method.invoke(this, args);
                }
    
                // 返回值类型
                Class<?> returnType        = method.getReturnType();
                System.out.println("returnType = " + returnType);
    
                // 处理sql
                String sql = handlerSql(interfaceType, method, args);
                // todo 发送http请求,具体的代码就步贴了, 用new一个实例代替
                return returnType.newInstance();
            };
        }
    
        private String handlerSql(Class interfaceType, @NotNull Method method, Object[] args) {
            // 这里直接调用mybatis里面sql解析相关的方法,生成sql。由于是直接使用的mybatis的功能,所以是支持动态sql的
            MappedStatement mappedStatement = getMappedStatement(interfaceType, method.getName());
            BoundSql        boundSql        = mappedStatement.getBoundSql(args);
            Map<String, Object> kv  = kv(method, args);
            String              sql = formatSql(configuration, boundSql, kv);
            System.out.println("sql = " + sql);
            return sql;
        }
    
        private Map<String, Object> kv(@NotNull Method method, Object[] args) {
            TreeMap<String, Object> paramMap       = new TreeMap<>();
            Parameter[]             parameters     = method.getParameters();
            int                     parameterCount = method.getParameterCount();
            for (int i = 0; i < parameterCount; i++) {
                String key   = parameters[i].getName();
                Object value = args[i];
                paramMap.put(key, value);
            }
            return paramMap;
        }
    
        // 这玩意儿可以写成一个单利,我这里就偷个懒,直接new了
        Configuration configuration = new Configuration();
    
        private MappedStatement getMappedStatement(Class interfaceType, String methodName) {
            String      resource = getXmlPath(interfaceType);
            InputStream inputStream;
            try {
                inputStream = Resources.getResourceAsStream(resource);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            XMLMapperBuilder builder = new XMLMapperBuilder(inputStream, configuration, resource, configuration.getSqlFragments());
            builder.parse();
    
            return configuration.getMappedStatement(methodName);
        }
    
        private String getXmlPath(Class interfaceType) {
            String baseDir = "xml文件的路径";
            return baseDir + interfaceType.getSimpleName() + ".xml";
        }
    
        // 这一段代码是从网上搬过来的
        private String formatSql(Configuration configuration, BoundSql boundSql, Map<String, Object> args) {
            String sql = boundSql.getSql();
            sql = beautifySql(sql);
            Object                 parameterObject     = boundSql.getParameterObject();
            List<ParameterMapping> parameterMappings   = boundSql.getParameterMappings();
            TypeHandlerRegistry    typeHandlerRegistry = configuration.getTypeHandlerRegistry();
    
            List<String> parameters = new ArrayList<>();
            if (parameterMappings != null) {
                MetaObject metaObject = args == null ? null : configuration.newMetaObject(args);
                for (ParameterMapping parameterMapping : parameterMappings) {
                    if (parameterMapping.getMode() != ParameterMode.OUT) {
                        //  参数值
                        Object value;
                        String propertyName = parameterMapping.getProperty();
                        //  获取参数名称
                        if (boundSql.hasAdditionalParameter(propertyName)) {
                            // 获取参数值
                            value = boundSql.getAdditionalParameter(propertyName);
                        } else if (parameterObject == null) {
                            value = null;
                        } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                            // 如果是单个值则直接赋值
                            value = parameterObject;
                        } else {
                            value = metaObject == null ? null : metaObject.getValue(propertyName);
                        }
    
                        if (value instanceof Number) {
                            parameters.add(String.valueOf(value));
                        } else {
                            StringBuilder builder = new StringBuilder();
                            builder.append("'");
                            if (value instanceof Date) {
                                builder.append(dateTimeFormatter.get().format((Date) value));
                            } else if (value instanceof String) {
                                builder.append(value);
                            }
                            builder.append("'");
                            parameters.add(builder.toString());
                        }
                    }
                }
            }
    
            for (String value : parameters) {
                sql = sql.replaceFirst("\\?", value);
            }
            return sql;
        }
    
        public static String beautifySql(String sql) {
            sql = sql.replaceAll("[\\s\n ]+", " ");
            return sql;
        }
    }
    
    

    方案一,实现 BeanDefinitionRegistryPostProcessor 接口

    
    import com.keeton.spring.custom.annotation.Keeton;
    import com.keeton.spring.custom.annotation.MyFactoryBean;
    import org.springframework.beans.BeansException;
    import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
    import org.springframework.beans.factory.support.BeanDefinitionBuilder;
    import org.springframework.beans.factory.support.BeanDefinitionRegistry;
    import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
    import org.springframework.beans.factory.support.GenericBeanDefinition;
    import org.springframework.context.ApplicationContext;
    import org.springframework.context.ApplicationContextAware;
    import org.springframework.context.ResourceLoaderAware;
    import org.springframework.core.env.PropertyResolver;
    import org.springframework.core.io.Resource;
    import org.springframework.core.io.ResourceLoader;
    import org.springframework.core.io.support.ResourcePatternResolver;
    import org.springframework.core.io.support.ResourcePatternUtils;
    import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
    import org.springframework.core.type.classreading.MetadataReader;
    import org.springframework.core.type.classreading.MetadataReaderFactory;
    import org.springframework.stereotype.Component;
    import org.springframework.util.ClassUtils;
    
    import java.io.IOException;
    import java.util.LinkedHashSet;
    import java.util.Set;
    
    @Component
    public class MyBeanDefinitionRegistryPostProcessor implements BeanDefinitionRegistryPostProcessor, ResourceLoaderAware, ApplicationContextAware {
    
    
        @Override
        public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
            //这里一般我们是通过反射获取需要代理的接口的clazz列表
            //比如判断包下面的类,或者通过某注解标注的类等等
            Set<Class<?>> beanClazzs = scannerPackages("com.keeton.spring.custom.annotation");
            for (Class beanClazz : beanClazzs) {
                BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz);
                GenericBeanDefinition definition = (GenericBeanDefinition) builder.getRawBeanDefinition();
    
                //在这里,我们可以给该对象的属性注入对应的实例。
                //比如mybatis,就在这里注入了dataSource和sqlSessionFactory,
                // 注意,如果采用definition.getPropertyValues()方式的话,
                // 类似definition.getPropertyValues().add("interfaceType", beanClazz);
                // 则要求在FactoryBean(本应用中即ServiceFactory)提供setter方法,否则会注入失败
                // 如果采用definition.getConstructorArgumentValues(),
                // 则FactoryBean中需要提供包含该属性的构造方法,否则会注入失败
                definition.getConstructorArgumentValues().addGenericArgumentValue(beanClazz);
    
                //注意,这里的BeanClass是生成Bean实例的工厂,不是Bean本身。
                // FactoryBean是一种特殊的Bean,其返回的对象不是指定类的一个实例,
                // 其返回的是该工厂Bean的getObject方法所返回的对象。
                definition.setBeanClass(MyFactoryBean.class);
    
                //这里采用的是byType方式注入,类似的还有byName等
                definition.setAutowireMode(GenericBeanDefinition.AUTOWIRE_BY_TYPE);
                registry.registerBeanDefinition(beanClazz.getSimpleName(), definition);
            }
    
        }
    
        private static final String DEFAULT_RESOURCE_PATTERN = "**/*.class";
        private ResourcePatternResolver resourcePatternResolver;
        private MetadataReaderFactory metadataReaderFactory;
    
        private Set<Class<?>> scannerPackages(String basePackage) {
            Set<Class<?>> set = new LinkedHashSet<>();
            String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX +
                    resolveBasePackage(basePackage) + '/' + DEFAULT_RESOURCE_PATTERN;
            try {
                Resource[] resources = this.resourcePatternResolver.getResources(packageSearchPath);
                for (Resource resource : resources) {
                    if (resource.isReadable()) {
                        MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(resource);
                        String className = metadataReader.getClassMetadata().getClassName();
                        Class<?> clazz;
                        try {
                            clazz = Class.forName(className);
                            if (clazz.isAnnotationPresent(Keeton.class))
                                set.add(clazz);
                        } catch (ClassNotFoundException e) {
                            e.printStackTrace();
                        }
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            return set;
        }
    
        protected String resolveBasePackage(String basePackage) {
            return ClassUtils.convertClassNameToResourcePath(this.getEnvironment().resolveRequiredPlaceholders(basePackage));
        }
    
        private PropertyResolver getEnvironment() {
            return applicationContext.getEnvironment();
        }
    
        @Override
        public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
    
        }
    
        private ApplicationContext applicationContext;
    
        @Override
        public void setResourceLoader(ResourceLoader resourceLoader) {
            this.resourcePatternResolver = ResourcePatternUtils.getResourcePatternResolver(resourceLoader);
            this.metadataReaderFactory = new CachingMetadataReaderFactory(resourceLoader);
        }
    
        @Override
        public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
            this.applicationContext = applicationContext;
        }
    
    }
    
    

    方案二,实现 ClassPathBeanDefinitionScanner 接口

    import org.springframework.context.annotation.Import;
    
    import java.lang.annotation.*;
    
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.TYPE})
    @Import(KeetonScannerRegistrar.class)
    public @interface KeetonScan {
    
        String[] value() default {};
    }
    
    
    import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
    import org.springframework.beans.factory.config.BeanDefinition;
    import org.springframework.beans.factory.config.BeanDefinitionHolder;
    import org.springframework.beans.factory.support.BeanDefinitionRegistry;
    import org.springframework.context.annotation.ClassPathBeanDefinitionScanner;
    import org.springframework.core.type.filter.AnnotationTypeFilter;
    
    import java.util.Set;
    
    public class KeetonScanner extends ClassPathBeanDefinitionScanner {
        public KeetonScanner(BeanDefinitionRegistry registry) {
            super(registry, false);
            addIncludeFilter(new AnnotationTypeFilter(Keeton.class));
        }
    
        @Override
        protected Set<BeanDefinitionHolder> doScan(String... basePackages) {
            Set<BeanDefinitionHolder> beanDefinitionHolders = super.doScan(basePackages);
            for (BeanDefinitionHolder beanDefinitionHolder : beanDefinitionHolders) {
                BeanDefinition definition    = beanDefinitionHolder.getBeanDefinition();
                String         beanClassName = definition.getBeanClassName();
                definition.setBeanClassName(MyFactoryBean.class.getName());
                try {
                    definition.getConstructorArgumentValues().addGenericArgumentValue(Class.forName(beanClassName));
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
            }
            return beanDefinitionHolders;
        }
    
    
        @Override
        protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
            return beanDefinition.getMetadata().isIndependent() && beanDefinition.getMetadata().isInterface();
        }
    }
    
    
    import org.springframework.beans.factory.support.BeanDefinitionRegistry;
    import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
    import org.springframework.core.annotation.AnnotationAttributes;
    import org.springframework.core.type.AnnotationMetadata;
    
    public class KeetonScannerRegistrar implements ImportBeanDefinitionRegistrar {
        @Override
        public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
            AnnotationAttributes annoAttrs = AnnotationAttributes.fromMap(importingClassMetadata
                    .getAnnotationAttributes(KeetonScan.class.getName()));//获取@Import所标注的注解信息
            KeetonScanner scanner = new KeetonScanner(registry);
            // AnnotationAttributes有获取各种注解信息的方法
            scanner.doScan(annoAttrs.getStringArray("value"));
        }
    }
    
    
    import java.util.Map;
    
    @Keeton
    public interface TestMapper {
        Map<String, Object> selectByPrimaryKey(String name, Integer gender);
    
    }
    
    
    <?xml version="1.0" encoding="UTF-8"?>
    <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
    <mapper namespace="TestMapper">
        <sql id="Base_Column_List">
            id, register_id, user_name, gender, birth_time, work_time, original_position, appoint_time,
            execute_time, approver, create_time, create_by, update_time, update_by, bef_change_title, aft_grade,
            aft_grade_wages, add_wages
        </sql>
    
        <select id="selectByPrimaryKey" resultType="java.util.Map">
            select
            <include refid="Base_Column_List"/>
            from yl_audit_five_years_promotion
            where user_name = #{name} and gender = #{gender}
        </select>
    </mapper>
    

    方案二的使用方式

    启动类

    @KeetonScan("com.keeton.spring.custom.annotation")
    @SpringBootApplication
    public class DefaultProxyCreatorApplication {
        public static void main(String[] args) {
            SpringApplication.run(DefaultProxyCreatorApplication.class, args);
        }
    }
    

    单元测试

    import org.junit.jupiter.api.Test;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.boot.test.context.SpringBootTest;
    
    
    @KeetonScan("com.keeton.spring.custom.annotation")
    @SpringBootTest
    class DefaultProxyCreatorApplicationTest {
    
        @Autowired
        private TestMapper testService;
    
        @Test
        void test1() {
            testService.selectByPrimaryKey("张三", 13);
        }
    }
    
    
    图片.png

    相关文章

      网友评论

          本文标题:spring 自定义注解、扫描器实现

          本文链接:https://www.haomeiwen.com/subject/cnytrdtx.html