美文网首页
SpringBoot 自动建表

SpringBoot 自动建表

作者: 东方不喵 | 来源:发表于2019-01-07 20:20 被阅读46次

    使用 MyBatis 或者 JdbcTemplate 的时候，框架并不会自动创建数据库表，需要额外花时间手动建表。为了省去这一步骤，可以编写一个简易的自动建表模板：根据实体类上的注解生成建表语句，并在启动时为缺失的表执行。
    GitHub https://github.com/oldguys/MultipleDataSourceDemo

    Step1: 引入 Maven 依赖。表结构通过 JPA 规范的注解（@Id、@GeneratedValue、@Column）来描述。

    <!-- JPA (javax.persistence) annotations such as @Id, @GeneratedValue and @Column,
         which DbRegister reads via reflection to build the table definitions -->
    <dependency>
        <groupId>javax.persistence</groupId>
        <artifactId>persistence-api</artifactId>
        <version>1.0</version>
    </dependency>
    

    Step2: 编写抽象模板

    package com.oldguy.example.modules.common.services;
    
    import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
    
    import java.util.List;
    import java.util.Map;
    
    /**
     * SQL-dialect strategy used by {@code DbRegister}. It knows how to list existing tables,
     * which Java field types map to which column types, and how to render CREATE TABLE DDL.
     *
     * @author ren
     * @CreateTime: 2018-10-2018/10/23 0023 13:53
     */
    public interface TableFactory {
    
        /**
         * @return a human-readable name of the targeted SQL dialect, e.g. "MySQL"
         */
        String getDialect();
    
        /**
         * @return a statement whose result rows name the tables that already exist
         *         (DbRegisterConfiguration treats every value of every row as a table name)
         */
        String showTableSQL();
    
        /**
         * @return mapping from Java field type to this dialect's column type name;
         *         DbRegister skips fields whose type is not a key of this map
         */
        Map<Class, String> getColumnType();
    
        /**
         * Renders the DDL for a set of table descriptions.
         *
         * @param sqlTableObjects table descriptions collected from entity classes
         * @return DDL text; keys are treated as table names by DbRegisterConfiguration
         */
        Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects);
    }
    
    

    Step3: 编写实现类，用于生成对应数据库方言的建表语句（Schema）

    package com.oldguy.example.modules.common.services.impls;
    
    import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
    import com.oldguy.example.modules.common.services.TableFactory;
    
    import java.util.Date;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    /**
     * MySQL implementation of {@link TableFactory}. It renders one
     * {@code CREATE TABLE IF NOT EXISTS ... ENGINE=InnoDB DEFAULT CHARSET=utf8} statement per
     * {@link SqlTableObject}.
     *
     * @author ren
     * @date 2018/12/20
     */
    public class MySQLTableFactory implements TableFactory {
    
        /** Length used for VARCHAR columns that do not specify one (MySQL rejects a bare VARCHAR). */
        private static final int DEFAULT_VARCHAR_LENGTH = 255;
    
        /** Java field type -> MySQL column type; one static map shared by all instances. */
        private static final Map<Class, String> columnType;
    
        static {
            columnType = new HashMap<>();
            columnType.put(Integer.class, "INT");
            columnType.put(Long.class, "BIGINT");
            columnType.put(String.class, "VARCHAR");
            columnType.put(Date.class, "DATETIME");
            columnType.put(Boolean.class, "TINYINT");
            columnType.put(Double.class, "DOUBLE");
        }
    
        @Override
        public String showTableSQL() {
            return "show tables";
        }
    
        @Override
        public String getDialect() {
            return "MySQL";
        }
    
        /**
         * @return the shared static type map (the same instance on every call)
         */
        @Override
        public Map<Class, String> getColumnType() {
            return columnType;
        }
    
        /**
         * Renders one CREATE TABLE statement per table description.
         *
         * @param sqlTableObjects table descriptions; {@code null} yields an empty map
         * @return DDL keyed by table name
         * @throws RuntimeException if two descriptions share the same table name
         */
        @Override
        public Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects) {
            if (sqlTableObjects == null) {
                return new HashMap<>();
            }
            Map<String, String> tableMap = new HashMap<>(sqlTableObjects.size());
            for (SqlTableObject obj : sqlTableObjects) {
                String tableName = obj.getTableName();
                // Check before rendering so that a duplicate fails without building its SQL first.
                if (tableMap.containsKey(tableName)) {
                    throw new RuntimeException(tableName + " 表名重复。");
                }
                tableMap.put(tableName, buildCreateTableSql(obj));
            }
            return tableMap;
        }
    
        /** Builds the complete CREATE TABLE statement for one table. */
        private static String buildCreateTableSql(SqlTableObject obj) {
            List<SqlTableObject.Column> columns = obj.getColumns();
            StringBuilder builder = new StringBuilder();
            builder.append("CREATE TABLE IF NOT EXISTS ").append(quote(obj.getTableName())).append(" (").append("\n");
            for (int i = 0; i < columns.size(); i++) {
                appendColumn(builder, columns.get(i));
                if (i < columns.size() - 1) {
                    builder.append(",");
                }
                builder.append("\n");
            }
            builder.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8 ;").append("\n\n");
            return builder.toString();
        }
    
        /** Appends a single column definition (name, type, constraints) without trailing separator. */
        private static void appendColumn(StringBuilder builder, SqlTableObject.Column column) {
            builder.append(quote(column.getName())).append(" ");
    
            String type = column.getType();
            if (type == null) {
                throw new IllegalStateException("Column " + column.getName() + " has no SQL type");
            }
            if ("VARCHAR".equalsIgnoreCase(type)) {
                // Case-insensitive so that columnDefinition = "varchar" also receives a length.
                builder.append(type).append("(");
                if (column.getLength() == null) {
                    builder.append(DEFAULT_VARCHAR_LENGTH);
                } else {
                    builder.append(column.getLength());
                }
                builder.append(")");
            } else {
                // Emitted verbatim: MySQL type keywords are case-insensitive, and upper-casing a
                // columnDefinition would also rewrite string literals such as DEFAULT 'abc'.
                builder.append(type);
            }
    
            if (column.isPrimaryKey()) {
                builder.append(" PRIMARY KEY");
                if (column.isAutoIncrement()) {
                    builder.append(" AUTO_INCREMENT");
                }
            }
            if (!column.isNullable()) {
                builder.append(" NOT NULL");
            }
            if (column.isUnique()) {
                builder.append(" UNIQUE");
            }
        }
    
        /** Wraps an identifier in backticks, doubling any embedded backtick as MySQL requires. */
        private static String quote(String identifier) {
            return "`" + String.valueOf(identifier).replace("`", "``") + "`";
        }
    }
    

    Step4: 编写扫描注册器

    package com.oldguy.example.modules.common.services;
    
    
    
    import com.oldguy.example.modules.common.annotation.AssociateEntity;
    import com.oldguy.example.modules.common.annotation.Entity;
    import com.oldguy.example.modules.common.dao.entities.SqlTableObject;
    import com.oldguy.example.modules.common.services.impls.MySQLTableFactory;
    import com.oldguy.example.modules.common.utils.ClassUtils;
    import org.springframework.util.StringUtils;
    
    import javax.persistence.Column;
    import javax.persistence.GeneratedValue;
    import javax.persistence.GenerationType;
    import javax.persistence.Id;
    import java.lang.reflect.Field;
    import java.util.*;
    
    /**
     * @Description: 数据表注册器
     * @Author: ren
     * @CreateTime: 2018-10-2018/10/16 0016 10:33
     */
    public class DbRegister {
    
        private TableFactory tableFactory;
    
        public DbRegister() {
            this.tableFactory = new MySQLTableFactory();
        }
    
        public DbRegister(TableFactory tableFactory) {
            this.tableFactory = tableFactory;
        }
    
    
        /**
         * 编写数据库配置文件
         *
         * @param packageNames
         */
        public Map<String, String> registerClassToDB(String... packageNames) {
    
            if (packageNames.length == 0) {
                return Collections.emptyMap();
            }
    
            List<Class> classList = new ArrayList<>();
            for (String packageName : packageNames) {
                classList.addAll(ClassUtils.getClasses(packageName));
            }
    
            List<SqlTableObject> sqlTableObjects = new ArrayList<>();
            for (Class clazz : classList) {
                if (clazz.isAnnotationPresent(Entity.class) || clazz.isAnnotationPresent(AssociateEntity.class)) {
    
                    SqlTableObject obj = new SqlTableObject();
                    String tableName = "";
                    String preIndex = "";
    
                    if (clazz.isAnnotationPresent(Entity.class)) {
                        Entity annotation = (Entity) clazz.getAnnotation(Entity.class);
                        tableName = annotation.name();
                        preIndex = annotation.pre();
                    } else if(clazz.isAnnotationPresent(AssociateEntity.class)){
                        AssociateEntity annotation = (AssociateEntity) clazz.getAnnotation(AssociateEntity.class);
                        tableName = annotation.name();
                        preIndex = annotation.pre();
                    }
    
                    tableName = StringUtils.isEmpty(tableName) ? preIndex + formatTableName(clazz.getSimpleName()) : tableName;
                    obj.setTableName(tableName);
    
                    // 配置字段
                    List<Field> fields = new ArrayList<>();
                    getAllField(clazz, fields);
                    setTableColumns(obj, fields);
                    sqlTableObjects.add(obj);
                }
            }
    
            //转换成为SQL Schema
            return trainToDBSchema(sqlTableObjects);
        }
    
        /**
         * 转换成为SQLSchema
         *
         * @param sqlTableObjects
         */
        private Map<String, String> trainToDBSchema(List<SqlTableObject> sqlTableObjects) {
    
            if (null == tableFactory) {
                throw new RuntimeException("TableFactory 不能为空!");
            }
    
            return tableFactory.trainToDBSchema(sqlTableObjects);
        }
    
        /**
         * 设置表格字段
         *
         * @param obj
         * @param fields
         */
        private void setTableColumns(SqlTableObject obj, List<Field> fields) {
    
            List<SqlTableObject.Column> columnList = new ArrayList<>();
            for (Field field : fields) {
                if (tableFactory.getColumnType().containsKey(field.getType())) {
    
                    SqlTableObject.Column column = new SqlTableObject.Column();
    
                    if (field.isAnnotationPresent(Id.class)) {
                        column.setPrimaryKey(true);
                        if (field.isAnnotationPresent(GeneratedValue.class)) {
                            GeneratedValue annotation = field.getAnnotation(GeneratedValue.class);
                            if (annotation.strategy().equals(GenerationType.AUTO)) {
                                column.setAutoIncrement(true);
                            }
                        }
                    }
    
                    if (field.isAnnotationPresent(Column.class)) {
                        Column annotation = field.getAnnotation(Column.class);
    
                        if (!StringUtils.isEmpty(annotation.name())) {
                            column.setName(annotation.name());
                        } else {
                            column.setName(formatTableName(field.getName()));
                        }
    
                        if (!StringUtils.isEmpty(annotation.columnDefinition())) {
                            column.setType(annotation.columnDefinition());
                        } else {
                            column.setType(tableFactory.getColumnType().get(field.getType()));
                        }
    
                        column.setLength(annotation.length());
                        column.setUnique(annotation.unique());
                        column.setNullable(annotation.nullable());
                    } else {
                        column.setName(formatTableName(field.getName()));
                        column.setType(tableFactory.getColumnType().get(field.getType()));
                    }
                    columnList.add(column);
                }
            }
            obj.setColumns(columnList);
        }
    
    
        /**
         * 获取所有的 字段
         *
         * @param clazz
         * @param fields
         */
        private static void getAllField(Class clazz, List<Field> fields) {
            fields.addAll(Arrays.asList(clazz.getDeclaredFields()));
            if (!clazz.getSuperclass().equals(Object.class)) {
                getAllField(clazz.getSuperclass(), fields);
            }
        }
    
        /**
         * 驼峰转双峰
         *
         * @param name
         * @return
         */
        public static String formatTableName(String name) {
            StringBuilder formatResult = new StringBuilder();
            char[] upperCaseArrays = name.toUpperCase().toCharArray();
            char[] defaultArrays = name.toCharArray();
    
            for (int i = 0; i < upperCaseArrays.length; i++) {
                if (i == 0) {
                    formatResult.append(String.valueOf(defaultArrays[0]).toLowerCase());
                    continue;
                }
                if (defaultArrays[i] == upperCaseArrays[i]) {
                    formatResult.append("_" + String.valueOf(defaultArrays[i]).toLowerCase());
                } else {
                    formatResult.append(defaultArrays[i]);
                }
            }
    
            return formatResult.toString();
        }
    
    
        public void setTableFactory(TableFactory tableFactory) {
            this.tableFactory = tableFactory;
        }
    
        public TableFactory getTableFactory() {
            return tableFactory;
        }
    
    }
    
    

    Step5: 编写数据库注册类（查询已存在的表，并为缺失的表执行建表语句）

    package com.oldguy.example.configs;
    
    
    import com.oldguy.example.modules.common.services.DbRegister;
    import com.oldguy.example.modules.common.utils.Log4jUtils;
    import org.springframework.beans.factory.annotation.Qualifier;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.jdbc.core.JdbcTemplate;
    import org.springframework.util.StringUtils;
    
    import javax.annotation.Resource;
    import javax.sql.DataSource;
    import java.util.*;
    
    /**
     * Database register. It creates the missing tables for the entity packages of one data source.
     *
     * @author ren
     * @date 2018/12/20
     */
    public class DbRegisterConfiguration {
    
        /** Separators accepted between package names: ',', ';', space, tab and newline. */
        private static final String PACKAGE_DELIMITERS = ",; \t\n";
    
        /**
         * Initialises the database. It renders DDL for every entity found in the given packages and
         * executes the DDL of each table that the database does not list yet.
         *
         * @param jdbcTemplate       template bound to the target data source
         * @param typeAliasesPackage one or more package names separated by ',', ';' or whitespace
         */
        public void initDB(JdbcTemplate jdbcTemplate, String typeAliasesPackage) {
    
            DbRegister dbRegister = new DbRegister();
            Map<String, String> tableMap = new HashMap<>();
            for (String path : splitPackagesPath(typeAliasesPackage)) {
                tableMap.putAll(dbRegister.registerClassToDB(path));
            }
    
            if (tableMap.isEmpty()) {
                return;
            }
    
            Set<String> tableNameSet = listExistingTables(jdbcTemplate, dbRegister.getTableFactory().showTableSQL());
            for (Map.Entry<String, String> table : tableMap.entrySet()) {
                String key = table.getKey();
                if (tableNameSet.contains(key)) {
                    Log4jUtils.getInstance(getClass()).info("表[" + key + "] 已存在");
                    continue;
                }
                Log4jUtils.getInstance(getClass()).info("未找到表[" + key + "],进行创建.");
                String sql = table.getValue();
                if (sql != null && sql.trim().length() > 0) {
                    jdbcTemplate.execute(sql);
                    Log4jUtils.getInstance(getClass()).info("\n\n" + sql);
                }
            }
        }
    
        /**
         * Runs the dialect's "show tables" statement and collects every non-null value of every
         * returned row as a table name.
         */
        private Set<String> listExistingTables(JdbcTemplate jdbcTemplate, String showTableSql) {
            Set<String> tableNameSet = new HashSet<>();
            for (Map<String, Object> row : jdbcTemplate.queryForList(showTableSql)) {
                for (Object value : row.values()) {
                    // toString() instead of a (String) cast: drivers may return other CharSequence types.
                    if (value != null) {
                        tableNameSet.add(value.toString());
                    }
                }
            }
            return tableNameSet;
        }
    
        /**
         * Splits the configured package list into trimmed, non-empty package names.
         *
         * @param typeAliasesPackage delimiter-separated package names; {@code null} yields an empty list
         * @return the package names in their original order
         */
        private List<String> splitPackagesPath(String typeAliasesPackage) {
            if (typeAliasesPackage == null) {
                return new ArrayList<>();
            }
            // Trims every token and drops empty ones, so "a; b;" yields [a, b].
            return new ArrayList<>(Arrays.asList(StringUtils.tokenizeToStringArray(typeAliasesPackage, PACKAGE_DELIMITERS)));
        }
    
    
    }
    
    

    以上模板基本构建完成。

    Step6: 开始配置：进行数据库注册，可以分别为多个数据源的数据库自动建表

    1. 注册类
    package com.oldguy.example.configs;
    
    import com.oldguy.example.modules.common.utils.Log4jUtils;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.beans.factory.annotation.Qualifier;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.jdbc.core.JdbcTemplate;
    
    import javax.annotation.PostConstruct;
    import javax.annotation.Resource;
    import javax.sql.DataSource;
    import java.io.IOException;
    
    /**
     * Demo wiring. After dependency injection ({@code @PostConstruct}) it creates the missing
     * tables of both demo data sources, test1 first and then test2.
     *
     * @author ren
     * @date 2018/12/20
     */
    @Configuration
    public class DemoConfiguration {
    
        @Resource(name = "test1DataSource")
        private DataSource test1dataSource;
    
        @Value("${test1.mybatis.type-aliases-package}")
        private String test1TypeAliasesPackage;
    
        @Resource(name = "test2DataSource")
        private DataSource test2dataSource;
    
        @Value("${test2.mybatis.type-aliases-package}")
        private String test2TypeAliasesPackage;
    
        /**
         * Runs table creation for each data source through its own {@link JdbcTemplate}, sharing one
         * {@link DbRegisterConfiguration}.
         */
        @PostConstruct
        public void initData() throws IOException {
            Log4jUtils.getInstance(getClass()).info("初始化 数据库 -------------------------------");
    
            DbRegisterConfiguration registrar = new DbRegisterConfiguration();
            registrar.initDB(new JdbcTemplate(test1dataSource), test1TypeAliasesPackage);
            registrar.initDB(new JdbcTemplate(test2dataSource), test2TypeAliasesPackage);
        }
    }
    
    

    2.注册数据源

    package com.oldguy.example.configs;
    
    import com.alibaba.druid.pool.DruidDataSource;
    import org.springframework.boot.context.properties.ConfigurationProperties;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.core.io.ResourceLoader;
    
    
    /**
     * Declares the Druid connection pools of both demo databases. Each pool's settings are bound
     * from its own {@code testN.datasource} property prefix.
     *
     * @author ren
     * @date 2018/12/20
     */
    @Configuration
    public class Test1DataSourceConfiguration extends AbstractMybatisConfiguration {
    
        /** Pool for the first database, configured from {@code test1.datasource.*}. */
        @Bean(name = "test1DataSource")
        @ConfigurationProperties(prefix = "test1.datasource")
        public DruidDataSource test1DataSource() {
            return newPool();
        }
    
        /** Pool for the second database, configured from {@code test2.datasource.*}. */
        @Bean(name = "test2DataSource")
        @ConfigurationProperties(prefix = "test2.datasource")
        public DruidDataSource test2DataSource() {
            return newPool();
        }
    
        /** Creates an unconfigured pool; @ConfigurationProperties fills in the settings afterwards. */
        private static DruidDataSource newPool() {
            return new DruidDataSource();
        }
    
    }
    
    

    3.yaml配置文件

    # First database: "datasource" is bound by @ConfigurationProperties(prefix = "test1.datasource");
    # "mybatis.type-aliases-package" is injected into DemoConfiguration for table creation.
    test1:
      datasource:
        # NOTE(review): hard-coded demo credentials; use per-environment secrets in real deployments.
        username: root
        password: root
        url: jdbc:mysql://127.0.0.1:3306/multiple_datasource1?useUnicode=true&characterEncoding=utf8&useSSL=false&allowMultiQueries=true
        # NOTE(review): newer MySQL Connector/J versions name the driver com.mysql.cj.jdbc.Driver — confirm the driver version in use.
        driver-class-name: com.mysql.jdbc.Driver
        type: com.alibaba.druid.pool.DruidDataSource
      mybatis:
        mapper-locations: classpath:mappers/test1/*.xml
        # Entity packages scanned for automatic table creation; ';'-separated, empty entries are skipped.
        type-aliases-package: com.oldguy.example.modules.test1.dao.entities;
        config-location: classpath:configs/myBatis-config.xml
    
    # Second database: same layout as test1, bound from the "test2" prefix.
    test2:
      datasource:
        username: root
        password: root
        url: jdbc:mysql://127.0.0.1:3306/multiple_datasource2?useUnicode=true&characterEncoding=utf8&useSSL=false&allowMultiQueries=true
        driver-class-name: com.mysql.jdbc.Driver
        type: com.alibaba.druid.pool.DruidDataSource
      mybatis:
        mapper-locations: classpath:mappers/test2/*.xml
        type-aliases-package: com.oldguy.example.modules.test2.dao.entities;
        config-location: classpath:configs/myBatis-config.xml
    
    

    这样，自动建表的配置就全部完成了。

    附：包扫描工具类 ClassUtils

    package com.oldguy.example.modules.common.utils;/**
     * Created by Administrator on 2018/10/16 0016.
     */
    
    import java.io.File;
    import java.io.FileFilter;
    import java.io.IOException;
    import java.net.JarURLConnection;
    import java.net.URL;
    import java.net.URLDecoder;
    import java.util.ArrayList;
    import java.util.Enumeration;
    import java.util.List;
    import java.util.jar.JarEntry;
    import java.util.jar.JarFile;
    
    /**
     * Reflection helpers that enumerate the classes of a package on the class path. Both exploded
     * directories ({@code file:} URLs) and jar files ({@code jar:} URLs) are supported.
     *
     * @author ren
     * @date 2018/12/20
     */
    public class ClassUtils {
    
        /** File-name suffix of compiled classes. */
        private static final String CLASS_SUFFIX = ".class";
    
        /**
         * Returns all classes in the given package and its sub-packages.
         *
         * @param pkg the package to scan
         * @return the classes found (possibly empty)
         */
        public static List<Class<?>> getAllClassByPackageName(Package pkg) {
            return getClasses(pkg.getName());
        }
    
        /**
         * Returns every class in the interface's own package (and sub-packages) that is assignable
         * to it, excluding the interface itself.
         *
         * @param c the interface
         * @return the implementing classes, or {@code null} if {@code c} is not an interface
         */
        public static List<Class<?>> getAllClassByInterface(Class<?> c) {
            if (!c.isInterface()) {
                return null;
            }
            List<Class<?>> returnClassList = new ArrayList<Class<?>>();
            for (Class<?> cls : getClasses(c.getPackage().getName())) {
                if (c.isAssignableFrom(cls) && !c.equals(cls)) {
                    returnClassList.add(cls);
                }
            }
            return returnClassList;
        }
    
        /**
         * Lists the entries (non-recursively) of the directory for {@code packageName} below
         * {@code classLocation}.
         *
         * @param classLocation root directory of the compiled classes
         * @param packageName   dot-separated package name
         * @return the file and directory names in the package directory, or {@code null} if that
         *         path is not a directory
         */
        public static String[] getPackageAllClassName(String classLocation, String packageName) {
            String realClassLocation = classLocation;
            for (String segment : packageName.split("[.]")) {
                realClassLocation = realClassLocation + File.separator + segment;
            }
            File packageDir = new File(realClassLocation);
            return packageDir.isDirectory() ? packageDir.list() : null;
        }
    
        /**
         * Returns all classes of the package and its sub-packages that are visible to the thread
         * context class loader. Falls back to this class's loader if the context loader is null.
         * Classes are loaded without running their static initializers. Classes that fail to load
         * are reported on stderr and skipped.
         *
         * @param packageName dot-separated package name
         * @return the classes found (possibly empty, never {@code null})
         */
        public static List<Class<?>> getClasses(String packageName) {
    
            List<Class<?>> classes = new ArrayList<>();
            // Sub-packages are always scanned.
            boolean recursive = true;
            String packageDirName = packageName.replace('.', '/');
    
            ClassLoader loader = Thread.currentThread().getContextClassLoader();
            if (loader == null) {
                loader = ClassUtils.class.getClassLoader();
            }
    
            try {
                Enumeration<URL> dirs = loader.getResources(packageDirName);
                while (dirs.hasMoreElements()) {
                    URL url = dirs.nextElement();
                    String protocol = url.getProtocol();
                    if ("file".equals(protocol)) {
                        // URLDecoder does form decoding; protect literal '+' so it is not turned into a space.
                        String filePath = URLDecoder.decode(url.getFile().replace("+", "%2B"), "UTF-8");
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes, loader);
                    } else if ("jar".equals(protocol)) {
                        findAndAddClassesInJar(url, packageDirName, recursive, classes, loader);
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
    
            return classes;
        }
    
        /**
         * Adds the classes of a package that lives inside a jar.
         *
         * @param url            {@code jar:} URL returned by the class loader for the package
         * @param packageDirName package name in '/'-separated form
         * @param recursive      whether classes of sub-packages are included
         * @param classes        output list
         * @param loader         loader used to load the classes
         */
        private static void findAndAddClassesInJar(URL url, String packageDirName, boolean recursive,
                                                   List<Class<?>> classes, ClassLoader loader) {
            // Require the trailing '/' so that package "a.b" does not also match "a/bc/...".
            String prefix = packageDirName.isEmpty() ? "" : packageDirName + "/";
            try {
                // Deliberately not closed: with URL caching (the default) the JarFile may be shared.
                JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
                Enumeration<JarEntry> entries = jar.entries();
                while (entries.hasMoreElements()) {
                    JarEntry entry = entries.nextElement();
                    String name = entry.getName();
                    if (name.startsWith("/")) {
                        name = name.substring(1);
                    }
                    if (entry.isDirectory() || !name.startsWith(prefix) || !name.endsWith(CLASS_SUFFIX)) {
                        continue;
                    }
                    if (!recursive && name.indexOf('/', prefix.length()) >= 0) {
                        // The class lives in a sub-package.
                        continue;
                    }
                    String className = name.substring(0, name.length() - CLASS_SUFFIX.length()).replace('/', '.');
                    loadClass(className, loader, classes);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    
        /**
         * Adds the classes found in a package directory on the file system.
         *
         * @param packageName package name corresponding to {@code packagePath}
         * @param packagePath directory of the package
         * @param recursive   whether sub-directories (sub-packages) are scanned
         * @param classes     output list
         * @param loader      loader used to load the classes
         */
        private static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive,
                                                             List<Class<?>> classes, ClassLoader loader) {
            File dir = new File(packagePath);
            if (!dir.exists() || !dir.isDirectory()) {
                return;
            }
            File[] dirfiles = dir.listFiles(new FileFilter() {
                // Keep sub-directories (when recursive) and compiled class files.
                @Override
                public boolean accept(File file) {
                    return (recursive && file.isDirectory()) || file.getName().endsWith(CLASS_SUFFIX);
                }
            });
            // listFiles returns null when the directory cannot be read.
            if (dirfiles == null) {
                return;
            }
            for (File file : dirfiles) {
                if (file.isDirectory()) {
                    findAndAddClassesInPackageByFile(qualify(packageName, file.getName()), file.getAbsolutePath(), recursive, classes, loader);
                } else {
                    String simpleName = file.getName().substring(0, file.getName().length() - CLASS_SUFFIX.length());
                    loadClass(qualify(packageName, simpleName), loader, classes);
                }
            }
        }
    
        /** Joins a package and a simple name, without a leading '.' for the default package. */
        private static String qualify(String packageName, String name) {
            return packageName.isEmpty() ? name : packageName + "." + name;
        }
    
        /**
         * Loads {@code className} through {@code loader} and appends it to {@code classes}. A class
         * that cannot be loaded (missing, or with a missing dependency) is reported and skipped, so
         * the rest of the scan can continue.
         */
        private static void loadClass(String className, ClassLoader loader, List<Class<?>> classes) {
            try {
                // initialize = false: scanning must not run static initializers as a side effect.
                classes.add(Class.forName(className, false, loader));
            } catch (ClassNotFoundException | LinkageError e) {
                e.printStackTrace();
            }
        }
    
    }
    
    

    至此，多数据源自动建表模板搭建完成。
    代码可以参考 GitHub https://github.com/oldguys/MultipleDataSourceDemo

    相关文章

      网友评论

          本文标题:SpringBoot 自动建表

          本文链接:https://www.haomeiwen.com/subject/umeprqtx.html