tensorflow--cifar10 dataset


Author: tu7jako | Published: 2020-06-09 22:18

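    The CIFAR-10 dataset contains 60,000 color images. Each image has 32 rows by 32 columns of pixels, with red, green, and blue channels. The images fall into ten classes: 50,000 are used for training and 10,000 for testing.
    The ten classes (labels 0 to 9) are airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck: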


    cifar10.png
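load_data() returns labels as integers from 0 to 9, stored in arrays of shape (N, 1); it does not return class names. Below is a minimal sketch for mapping labels back to names. It assumes the standard CIFAR-10 label order listed above. CIFAR10_CLASSES and label_names are illustrative names and are not part of the Keras API.

```python
# Assumed standard CIFAR-10 class order: label i corresponds to CIFAR10_CLASSES[i].
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def label_names(labels):
    """Map integer labels to class names (hypothetical helper).

    Accepts a flat list of ints or a list of one-element rows,
    mirroring the (N, 1) shape of the label arrays from load_data().
    """
    names = []
    for label in labels:
        # Unwrap one-element rows such as [3] into the bare integer 3.
        if isinstance(label, (list, tuple)):
            (label,) = label
        names.append(CIFAR10_CLASSES[int(label)])
    return names


if __name__ == "__main__":
    # Example label values chosen for illustration, not taken from the dataset.
    print(label_names([[0], [3], [9]]))  # ['airplane', 'cat', 'truck']
```

To use it on real labels, first convert a NumPy slice into nested Python lists, e.g. label_names(y_train[:4].tolist()).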

    Load the dataset:

    import tensorflow as tf

    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
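load_data() returns images as uint8 arrays of shape (N, 32, 32, 3) with values from 0 to 255, and labels as uint8 arrays of shape (N, 1). The training script divides the images by 255.0. Below is a small NumPy sketch of that preprocessing. It runs on a random stand-in batch of 8 images rather than the real dataset; the random data is purely an assumption for illustration.

```python
import numpy as np

# Stand-in batch with CIFAR-10's layout: (N, 32, 32, 3) uint8 images
# and (N, 1) uint8 labels. Random values replace the real dataset here.
rng = np.random.default_rng(0)
x_batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
y_batch = rng.integers(0, 10, size=(8, 1), dtype=np.uint8)

# Same normalization as the training script: scale pixels into [0, 1].
# Dividing a uint8 array by a Python float produces a float64 array.
x_norm = x_batch / 255.0

print(x_norm.shape, x_norm.dtype)                # (8, 32, 32, 3) float64
print(x_norm.min() >= 0.0, x_norm.max() <= 1.0)  # True True

# SparseCategoricalCrossentropy accepts (N, 1) integer labels directly,
# so no one-hot encoding is needed; ravel() gives a flat (N,) view if wanted.
print(y_batch.ravel().shape)                     # (8,)
```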
    

    Build a network with one convolutional layer and two fully connected layers, and train it on CIFAR-10:

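    import os
    import sys

    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import Model
    from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                         Dense, Dropout, Flatten, MaxPool2D)

    # Print arrays in full (no "..." truncation) when writing weights to a text file
    np.set_printoptions(threshold=sys.maxsize)


    # Load the dataset and scale pixel values from [0, 255] to [0, 1]
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Build the neural network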
    class BaseLine(Model):
        def __init__(self):
            super(BaseLine, self).__init__()
            # One convolutional block, CBAPD: Conv -> BatchNorm -> Activation -> Pooling -> Dropout
            self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding="same")
            self.b1 = BatchNormalization()
            self.a1 = Activation("relu")
            self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
            self.d1 = Dropout(0.2)
    
            # Two fully connected layers
            self.flatten = Flatten()
            self.f1 = Dense(128, activation="relu")
            self.d2 = Dropout(0.2)
            self.f2 = Dense(10, activation="softmax")
    
        # Forward pass through the network
        def call(self, x):
            x = self.c1(x)
            x = self.b1(x)
            x = self.a1(x)
            x = self.p1(x)
            x = self.d1(x)
            x = self.flatten(x)
            x = self.f1(x)
            x = self.d2(x)
            y = self.f2(x)
            return y
    
    
    model = BaseLine()
    
    # Configure training; from_logits=False because the last layer already applies softmax
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["sparse_categorical_accuracy"]
    )
    
    # Resume training from a checkpoint: load saved weights if they exist
    checkpoint_save_path = "cifar10/BaseLine.ckpt"
    if os.path.exists(checkpoint_save_path + ".index"):
        print("*******load the model******")
        model.load_weights(checkpoint_save_path)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True
    )
    
    # Train the model, evaluating on the test set after every epoch
    history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test),
                        validation_freq=1, callbacks=[cp_callback])
    
    # Print the network structure and parameter counts
    model.summary()
    
    # Write each trainable variable's name, shape, and values to a text file
    with open("cifar10_weights.txt", "w") as f:
        for v in model.trainable_variables:
            f.write(str(v.name) + "\n")
            f.write(str(v.shape) + "\n")
            f.write(str(v.numpy()) + "\n")
    
    
    # Plot the training and validation accuracy and loss curves
    acc = history.history["sparse_categorical_accuracy"]
    val_acc = history.history["val_sparse_categorical_accuracy"]
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    plt.subplot(1, 2, 1)
    plt.plot(acc, label="train acc")
    plt.plot(val_acc, label="validation acc")
    plt.title("train & validation acc")
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(loss, label="train loss")
    plt.plot(val_loss, label="validation loss")
    plt.title("train & validation loss")
    plt.legend()
    plt.show()
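To see how large the first Dense layer's input is, trace the feature-map shape through the CBAPD block. With TensorFlow's padding="same", the output size is ceil(input / stride). The 5x5 convolution (stride 1) therefore keeps 32x32, and the 2x2 max pool with stride 2 gives 16x16. Flatten then receives 16 * 16 * 6 = 1536 values. Below is a pure-Python sketch of that arithmetic; the helper names are hypothetical.

```python
import math


def conv_same_shape(h, w, filters, stride=1):
    """Conv2D output shape with padding="same": ceil(input / stride)."""
    return math.ceil(h / stride), math.ceil(w / stride), filters


def pool_same_shape(h, w, c, stride=2):
    """MaxPool2D output shape with padding="same": ceil(input / stride)."""
    return math.ceil(h / stride), math.ceil(w / stride), c


def baseline_shapes(h=32, w=32):
    """Trace BaseLine's shapes for one h x w RGB image (batch axis omitted)."""
    shapes = {"input": (h, w, 3)}
    shapes["conv"] = conv_same_shape(h, w, filters=6)   # c1; b1 and a1 keep the shape
    shapes["pool"] = pool_same_shape(*shapes["conv"])   # p1; d1 keeps the shape
    hp, wp, cp = shapes["pool"]
    shapes["flatten"] = (hp * wp * cp,)                 # flatten
    shapes["dense"] = (128,)                            # f1; d2 keeps the shape
    shapes["output"] = (10,)                            # f2, softmax over 10 classes
    return shapes


if __name__ == "__main__":
    for name, shape in baseline_shapes().items():
        print(f"{name:8s}{shape}")
```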
    
    

    Output of model.summary() (truncated):

    _________________________________________________________________
    flatten (Flatten)            multiple                  0         
    _________________________________________________________________
    dense (Dense)                multiple                  196736    
    _________________________________________________________________
    dropout_1 (Dropout)          multiple                  0         
    _________________________________________________________________
    dense_1 (Dense)              multiple                  1290      
    =================================================================
    Total params: 198,506
    Trainable params: 198,494
    Non-trainable params: 12
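The totals above can be checked by hand:
- A Conv2D layer has kernel_h * kernel_w * in_channels * filters weights plus one bias per filter.
- BatchNormalization keeps four vectors per channel. Gamma and beta are trained; the moving mean and variance are not, which accounts for the 12 non-trainable parameters.
- A Dense layer has in * out weights plus out biases.

Below is a pure-Python sketch of that arithmetic (helper names are illustrative).

```python
def conv2d_params(kh, kw, c_in, filters):
    """Weights of a Conv2D kernel plus one bias per filter."""
    return kh * kw * c_in * filters + filters


def batchnorm_params(channels):
    """Return (trainable, non_trainable): gamma/beta vs. moving mean/variance."""
    return 2 * channels, 2 * channels


def dense_params(n_in, n_out):
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out


conv = conv2d_params(5, 5, 3, 6)          # c1: 5x5 kernel, 3 input channels, 6 filters
bn_train, bn_fixed = batchnorm_params(6)  # b1: one set of statistics per filter
fc1 = dense_params(16 * 16 * 6, 128)      # f1: input is the flattened 16x16x6 map
fc2 = dense_params(128, 10)               # f2: 10 output classes

trainable = conv + bn_train + fc1 + fc2
total = trainable + bn_fixed

print(conv, bn_train + bn_fixed, fc1, fc2)  # 456 24 196736 1290
print(total, trainable, bn_fixed)           # 198506 198494 12
```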
    

    Resulting curves:


    myplot.png
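The curves plot two quantities: loss (sparse categorical cross-entropy) and sparse_categorical_accuracy. With from_logits=False, the softmax outputs are treated as probabilities. Each sample's loss is therefore -log of the probability assigned to its true class, and a sample counts as correct when the highest-probability class equals its label. Keras averages both over each epoch. Below is a pure-Python sketch of the per-batch values on made-up probabilities, not results from this model. The clipping epsilon of 1e-7 is an assumption meant to mirror Keras's default.

```python
import math


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean of -log(p[true class]) over a batch of softmax outputs.

    Probabilities are clipped to [eps, 1 - eps] to avoid log(0).
    """
    losses = []
    for label, row in zip(labels, probs):
        p = min(max(row[label], eps), 1.0 - eps)
        losses.append(-math.log(p))
    return sum(losses) / len(losses)


def sparse_categorical_accuracy(labels, probs):
    """Fraction of samples whose highest-probability class equals the label."""
    hits = sum(
        1 for label, row in zip(labels, probs)
        if max(range(len(row)), key=row.__getitem__) == label
    )
    return hits / len(labels)


if __name__ == "__main__":
    # Made-up 3-class outputs for two samples (CIFAR-10 would have 10 columns).
    labels = [0, 2]
    probs = [[0.7, 0.2, 0.1],
             [0.5, 0.3, 0.2]]
    print(sparse_categorical_accuracy(labels, probs))  # 0.5
    print(round(sparse_categorical_crossentropy(labels, probs), 4))
```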


          Article link: https://www.haomeiwen.com/subject/xagetktx.html