[TF1: Fully-Connected Layer Demo]

Author: 唯师默蓝 | Published 2020-02-20 00:38
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np
    import sklearn
    import pandas as pd
    import os
    import sys
    import time
    import tensorflow as tf
    from tensorflow import keras
    
    
    fashion_mnist=keras.datasets.fashion_mnist
    (x_train_all,y_train_all),(x_test,y_test)=fashion_mnist.load_data()
    
    x_valid,x_train=x_train_all[:5000],x_train_all[5000:]
    y_valid,y_train=y_train_all[:5000],y_train_all[5000:]
    
    print(x_valid.shape,y_valid.shape)
    print(x_train.shape,y_train.shape)
    print(x_test.shape,y_test.shape)
    
    print(np.max(x_train),np.min(x_train))
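    # Expected output, given the standard Fashion-MNIST split of 60000
    # training / 10000 test images of 28x28 uint8 pixels:
    #   (5000, 28, 28) (5000,)
    #   (55000, 28, 28) (55000,)
    #   (10000, 28, 28) (10000,)
    #   255 0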
    
    from sklearn.preprocessing import StandardScaler
    scaler=StandardScaler()
    x_train_scaled=scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28*28)
    x_valid_scaled=scaler.transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28*28)
    x_test_scaled=scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28*28)
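    # Note (added): reshape(-1, 1) stacks every pixel into one column so the
    # scaler learns a single global mean/std from the training set; the second
    # reshape restores one flat 784-dim vector per image. A quick sanity
    # check, assuming the standardization worked:
    print(np.mean(x_train_scaled), np.std(x_train_scaled))  # expect ~0.0, ~1.0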
    
    hidden_units=[100,100]
    class_num=10
    # Create placeholders for x and y.
    # TF1 separates building the graph from running it; no real data exists
    # at graph-construction time, so placeholders hold the inputs' place.
    x=tf.placeholder(tf.float32,[None,28*28])
    y=tf.placeholder(tf.int64,[None])
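    # x holds a batch of flattened 28*28 images; y holds integer class labels
    # (0-9). Both are filled in via feed_dict at sess.run time.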
    
    # Temporary variable that carries each layer's output into the next layer
    input_for_next_layer=x
    for hidden_unit in hidden_units:
        # Each iteration of the loop builds a new dense layer; its output is
        # assigned back to input_for_next_layer so it feeds the next layer
        input_for_next_layer = tf.layers.dense(input_for_next_layer,hidden_unit,activation=tf.nn.relu)
    # Output layer.
    # logits = (output of the last hidden layer) * W + b; applying softmax to
    # the logits later yields class probabilities.
    logits = tf.layers.dense(input_for_next_layer,class_num)
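    # Resulting network: 784 -> 100 (ReLU) -> 100 (ReLU) -> 10 logits.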
    # Compute the loss. sparse_softmax_cross_entropy does three things:
    # 1. logits -> softmax -> prob
    # 2. labels -> one-hot
    # 3. calculate the cross-entropy between the two
    loss=tf.losses.sparse_softmax_cross_entropy(labels=y,logits=logits)
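    # For reference (a sketch, not from the original article), the fused op
    # above is roughly equivalent to the manual three-step version:
    #   probs = tf.nn.softmax(logits)
    #   onehot = tf.one_hot(y, class_num)
    #   loss = tf.reduce_mean(-tf.reduce_sum(onehot * tf.log(probs), axis=1))
    # The fused op is preferred because it is more numerically stable.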
    # prediction is the index of the largest logit, i.e. the predicted class
    prediction=tf.argmax(logits,1)
    # Compare predictions against the labels;
    # correct_prediction is a boolean vector (True = correct, False = wrong)
    correct_prediction=tf.equal(prediction,y)
    # Cast correct_prediction to float, then average it to get the accuracy
    accuracy=tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
    
    # train_op trains the network: each run performs one optimization step
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    
    # Session
    # Create the op that initializes all global variables
    init=tf.global_variables_initializer()
    batch_size=20
    epochs=10
    # Steps per epoch = total number of samples floor-divided by batch_size
    train_steps_per_epoch=x_train.shape[0] // batch_size
    valid_steps=x_valid.shape[0]//batch_size
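    # With 55000 training and 5000 validation samples at batch_size=20, that
    # is 2750 training steps and 250 validation steps per epoch.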
    
    def eval_with_sess(sess, x, y, accuracy, images, labels, batch_size):
        eval_steps = images.shape[0] // batch_size
        eval_accuracies = []
        for step in range(eval_steps):
            batch_data = images[step * batch_size : (step+1) * batch_size]
            batch_label = labels[step * batch_size : (step+1) * batch_size]
            accuracy_val = sess.run(accuracy,
                                    feed_dict = {
                                        x: batch_data,
                                        y: batch_label
                                    })
            eval_accuracies.append(accuracy_val)
        return np.mean(eval_accuracies)
    
    with tf.Session() as sess:
        sess.run(init)  # run the initializer: variables get their values here
        # Train the network step by step by repeatedly running train_op
        for epoch in range(epochs):
            for step in range(train_steps_per_epoch):
                batch_data = x_train_scaled[
                             step * batch_size: (step + 1) * batch_size]
                batch_label = y_train[
                              step * batch_size: (step + 1) * batch_size]
                loss_val, accuracy_val, _ = sess.run(
                    [loss, accuracy, train_op],
                    feed_dict={
                        x: batch_data,
                        y: batch_label
                    })
                print('\r[Train] epoch: %d, step: %d, loss: %3.5f, accuracy: %2.2f' % (
                    epoch, step, loss_val, accuracy_val), end="")
            valid_accuracy = eval_with_sess(sess, x, y, accuracy,
                                            x_valid_scaled, y_valid,
                                            batch_size)
            print("\t[Valid] acc: %2.2f" % (valid_accuracy))
    
