美文网首页码农的世界人工智能程序员
JAVAEE与人工智能实战之--通过MNIST进行模型训练

JAVAEE与人工智能实战之--通过MNIST进行模型训练

作者: 山东大葱哥 | 来源:发表于2019-04-08 08:30 被阅读5次

    MNIST简介

    一个手写数字识别库,世界上最权威的,美国邮政系统开发的,手写内容是0-9的内容,手写内容采集于美国人口调查局的员工和高中生。包括6万张训练图片和1万张测试图片构成的,每张图片都是28*28大小,而且都是黑白色构成。

    MINIST实验包含了四个文件,其中train-images-idx3-ubyte是60000个图片样本,train-labels-idx1-ubyte是这60000个图片对应的数字标签,t10k-images-idx3-ubyte是用于测试的样本,t10k-labels-idx1-ubyte是测试样本对应的数字标签。

    我们以测试集中的一个图片为例来说明图片的存储形式:

    MNIST图片并不是传统意义上的png或者jpg格式的图片,因为png或者jpg的图片格式,会带有很多干扰信息(如:数据块,图片头,图片尾,长度等等),这些图片会被处理成很简易的数组,图片长度为28,宽度也为28,总像素为2828=784,在MNIST存储的就是一个长度为784的数组,数组中的每个值表示每个点的RGB值,其中黑色用0表示、白色用255表示。我们可以将数组转成2828的二维数组,如下图所示,可以看出这是一个表示的是数字5的图片。

    image.png

    如果把像素写成图片,图片是这样的:


    image.png

    通过MNIST训练模型

    在BP神经网络中, 层数、节点个数、学习速率、训练集、训练次数,都会影响到最终模型的泛化能力。因此,在设计模型时,节点的个数,学习速率的大小,以及训练次数都是需要考虑的。

    本实例中设置神经网络层数为3层,其中输入特征为784个,每层节点数分别为300、100、10个,学习速率设置为0.5,迭代周期为30,批量设置60个。通过训练该模型在MNIST测试集上的平均准确率为96.68 %左右。

    public static void main(String[] args) {
            //三层网络,各层节点数为784*300*10 输入特征 784个  隐藏层节点300个 输出层节点10个
            int[] nodeNum = {784, 300,100, 10};
            //周期被定义为向前和向后传播中所有批次的单次训练迭代。
            int epoch = 30;
            //每次批量的样本数
            int batchSize = 60;
            double learningRate=0.5;
            NetTrainAndTest.train(nodeNum, epoch, batchSize,learningRate);
        }
    

    对模型进行序列化

    为了“一次训练、多次使用”,我们对训练好的模型进行序列化存储,后续即可通过反序列化的方式读取恢复模型。

        /**
         * 通过序列化方式存储模型
         *
         * @param fileName 模型存放的文件名
         */
        public static <T> void saveModel(String fileName, T obj) {
            try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(fileName));
                 ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                oos.writeObject(obj);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    
        /**
         * 恢复模型
         *
         * @param fileName 模型持久化的存放位置 文件名
         *                 <p>
         *                 //@SuppressWarnings("unchecked")
         */
        public static <T> T restoreModel(String fileName) {
            try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(fileName));
                 ObjectInputStream ois = new ObjectInputStream(bis)) {
                return (T) ois.readObject();
            } catch (IOException | ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    
    
    上一篇 JAVAEE与人工智能目录 [下一篇]

    相关文章

      网友评论

        本文标题:JAVAEE与人工智能实战之--通过MNIST进行模型训练

        本文链接:https://www.haomeiwen.com/subject/hvgabqtx.html