Data Shape Changes in TensorFlow CNN Image Classification


Author: yingtaomj | Published 2017-10-27 17:46

https://zhuanlan.zhihu.com/p/27288913的基础上,重写了tf.Graph。

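    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

    BATCH_SIZE = 64  # batch size used for both training and evaluation

    global_step = tf.Variable(0, trainable=False)
    # Placeholders: the first dimension is fixed to BATCH_SIZE
    images = tf.placeholder(tf.float32, [BATCH_SIZE, 32, 32, 3], name='images')  # (64, 32, 32, 3)
    labels = tf.placeholder(tf.int32, (BATCH_SIZE,), name='labels')              # (64,), integer class ids

    print("Done Initializing Training Placeholders")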

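labels are not one-hot encoded: each label is simply the class number itself.
The first dimension of every placeholder is fixed to batch_size, so every feed must contain exactly BATCH_SIZE examples.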
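To make the two label formats and the fixed batch dimension concrete, here is a small NumPy sketch. It builds random stand-in data with the same shapes the placeholders expect and converts the integer labels to one-hot form and back. BATCH_SIZE = 64 follows the evaluation comment later in the post; NUM_CLASS = 10 is assumed from the 10-way logits. The arrays are illustrative only, not real image data.

```python
import numpy as np

BATCH_SIZE = 64   # assumed: the batch size mentioned in the evaluation comment
NUM_CLASS = 10    # assumed: 10 classes, matching the (batch_size, 10) logits

rng = np.random.default_rng(0)

# Random stand-ins with exactly the shapes the placeholders accept
x_batch = rng.random((BATCH_SIZE, 32, 32, 3), dtype=np.float32)         # images: (64, 32, 32, 3)
y_batch = rng.integers(0, NUM_CLASS, size=BATCH_SIZE).astype(np.int32)  # labels: (64,), class ids

# One-hot version of the same labels: (64, 10), a single 1.0 per row
y_onehot = np.eye(NUM_CLASS, dtype=np.float32)[y_batch]

# argmax along the class axis turns one-hot rows back into class ids
y_back = y_onehot.argmax(axis=1).astype(np.int32)

print(x_batch.shape, y_batch.shape, y_onehot.shape)
```

Feeding y_onehot to the labels placeholder would be rejected, because its shape (64, 10) does not match the placeholder's (64,); the sparse y_batch is the form this graph expects.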

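    # Build a Graph that computes the logits predictions from the placeholder.
    # logits: (BATCH_SIZE, 10), one unnormalized score per class
    logits = CNN(images)

    # Calculate loss: compares logits (BATCH_SIZE, 10) with integer labels (BATCH_SIZE,)
    loss = cal_loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters. Passing global_step makes minimize()
    # increment the step counter after each update.
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)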
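CNN() is defined in the linked repo and is not reproduced here. As a rough illustration of how a (64, 32, 32, 3) batch can shrink to (64, 10) logits, the sketch below traces shapes through a hypothetical CIFAR-style stack (two 5×5 conv layers, each followed by 3×3 stride-2 max pooling, then flatten and a dense layer). The layer list is an assumption, not the repo's architecture; the SAME/VALID output-size formulas are the ones TensorFlow uses for conv2d and max_pool.

```python
import math

def conv_out(size, kernel, stride, padding):
    """Output length along one spatial axis for a TF conv/pool layer."""
    if padding == "SAME":
        return math.ceil(size / stride)
    if padding == "VALID":
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError("padding must be 'SAME' or 'VALID'")

def trace_shapes(batch, height, width, channels, layers, num_classes):
    """Return the NHWC shape after each layer, then the flattened and logits shapes."""
    shapes = [(batch, height, width, channels)]
    for kind, kernel, stride, padding, filters in layers:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
        if kind == "conv":
            channels = filters  # conv changes depth; pooling keeps it
        shapes.append((batch, height, width, channels))
    shapes.append((batch, height * width * channels))  # flatten before the dense layer
    shapes.append((batch, num_classes))                 # final dense layer -> logits
    return shapes

# Hypothetical layer stack: NOT necessarily what CNN() in the repo uses
LAYERS = [
    ("conv", 5, 1, "SAME", 64),
    ("pool", 3, 2, "SAME", None),
    ("conv", 5, 1, "SAME", 64),
    ("pool", 3, 2, "SAME", None),
]

for shape in trace_shapes(64, 32, 32, 3, LAYERS, 10):
    print(shape)
```

Whatever the real layers are, the batch dimension stays fixed at BATCH_SIZE throughout, and only the last dense layer decides the final width of 10.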

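logits has shape (batch_size, 10): one raw, unnormalized score per class for each image. It is not one-hot; it simply has one column per class, the same layout a one-hot label would use.
In cal_loss, logits has shape (batch_size, 10) while labels has shape (batch_size,), a single integer class index per image rather than a one-hot row. That is why the loss uses tf.nn.sparse_softmax_cross_entropy_with_logits, which expects labels with one dimension fewer than logits and returns one loss value per example, with shape (batch_size,).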
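The repo's cal_loss is not shown. A plausible reading is that it calls tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits) and then averages the result. The NumPy sketch below re-implements that computation (it is not the TensorFlow code) to show the shapes: (batch_size, 10) logits plus (batch_size,) integer labels give a (batch_size,) vector of per-example losses, which the mean reduces to a scalar. It also checks that the sparse form matches ordinary cross-entropy on one-hot labels.

```python
import numpy as np

def log_softmax(logits):
    """Row-wise log-softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def sparse_softmax_xent(logits, labels):
    """logits (B, C), integer labels (B,) -> per-example loss (B,)."""
    log_probs = log_softmax(logits)
    # Pick the log-probability of the true class in each row
    return -log_probs[np.arange(logits.shape[0]), labels]

def dense_softmax_xent(logits, onehot):
    """Same loss with one-hot labels (B, C) -> (B,)."""
    return -(onehot * log_softmax(logits)).sum(axis=1)

rng = np.random.default_rng(1)
logits = rng.normal(size=(64, 10))         # (batch_size, 10)
labels = rng.integers(0, 10, size=64)      # (batch_size,)

per_example = sparse_softmax_xent(logits, labels)  # (batch_size,)
loss = per_example.mean()                          # scalar, like tf.reduce_mean
print(per_example.shape, loss)
```

The sparse version never materializes the one-hot matrix; it just indexes the true-class column, which is why integer labels of shape (batch_size,) are enough.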

Training:

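    import math
    import time
    from datetime import datetime

    # Assumes `sess` is an open tf.Session whose variables are already initialized,
    # X_train has shape (N, 32, 32, 3) and y_train has shape (N,).
    data_length = len(X_train)
    nb_batches = int(math.ceil(data_length / float(BATCH_SIZE)))

    for step in range(1000):
        # Current batch number
        batch_nb = step % nb_batches

        # Current batch start and end indices
        start, end = utils.batch_indices(batch_nb, data_length, BATCH_SIZE)

        # Prepare dictionary to feed the session with
        feed_dict = {images: X_train[start:end],
                     labels: y_train[start:end]}

        # Run one training step and time it
        start_time = time.time()
        _, loss_value = sess.run([train_step, loss], feed_dict=feed_dict)
        duration = time.time() - start_time

        # Echo loss once in a while
        if step % 20 == 0:
            num_examples_per_step = BATCH_SIZE
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch))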
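utils.batch_indices comes from the repo and is not shown. Because both placeholders have a fixed first dimension, every slice it returns must hold exactly BATCH_SIZE rows. The sketch below is a hypothetical re-implementation under that assumption: when a batch would run past the end of the data, it is shifted back so it stays full-size.

```python
import math

def batch_indices(batch_nb, data_length, batch_size):
    """Hypothetical stand-in for utils.batch_indices (assumed behavior)."""
    start = int(batch_nb * batch_size)
    end = int((batch_nb + 1) * batch_size)
    # Not enough examples left: shift the window back so it keeps batch_size rows
    if end > data_length:
        shift = end - data_length
        start -= shift
        end -= shift
    return start, end

data_length, BATCH_SIZE = 1000, 64
nb_batches = int(math.ceil(data_length / float(BATCH_SIZE)))

for b in (0, 1, nb_batches - 1):
    print(b, batch_indices(b, data_length, BATCH_SIZE))
```

With 1000 examples and a batch size of 64 there are 16 batches. Under this assumption the last one covers rows 936 to 999 and overlaps the previous batch, so a few examples are seen twice per pass; during evaluation those rows are just overwritten with identical predictions.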

Evaluation:

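    import math
    import numpy as np

    # The test set has 1000 examples; evaluate them in chunks of BATCH_SIZE (64).
    # float() keeps the division exact under Python 2 as well, so 1000/64 -> 16 batches.
    newbatch = int(math.ceil(1000 / float(BATCH_SIZE)))
    preds = np.zeros((1000, NUM_CLASS), dtype=np.float32)
    for cnt in range(newbatch):
        # Compute batch start and end indices. images has a fixed batch dimension,
        # so each slice must hold exactly BATCH_SIZE rows (batch_indices is assumed
        # to shift the last batch back to keep it full-size).
        start, end = utils.batch_indices(cnt, 1000, BATCH_SIZE)
        # Prepare feed dictionary: inference only needs the images
        feed_dict = {images: X_test[start:end]}
        # sess.run returns a list with one entry per fetch; [0] is the logits array
        preds[start:end, :] = sess.run([logits], feed_dict=feed_dict)[0]

    # preds holds raw logits; accuracy() is defined elsewhere in the repo
    precision = accuracy(preds, y_test)
    print('Precision of teacher after training: ' + str(precision))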
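accuracy() is another repo helper that is not shown. Assuming it takes the (1000, NUM_CLASS) score matrix and the (1000,) integer labels, it only needs an argmax over the class axis. The hypothetical stand-in below shows that shape change. Storing raw logits in preds is enough, because softmax preserves the order of scores within a row and so does not change the argmax.

```python
import numpy as np

def accuracy(preds, labels):
    """Hypothetical stand-in for the repo's accuracy() helper.

    preds:  (N, C) array of per-class scores (logits or probabilities)
    labels: (N,) array of integer class ids
    """
    predicted = np.argmax(preds, axis=1)  # (N, C) -> (N,): highest-scoring class per row
    return float(np.mean(predicted == labels))

preds = np.array([[0.1, 2.0, -1.0],
                  [3.0, 0.0, 0.0],
                  [0.0, 0.0, 5.0],
                  [1.0, 0.5, 0.2]], dtype=np.float32)
labels = np.array([1, 0, 1, 0])
print(accuracy(preds, labels))  # 0.75: the third row predicts class 2 instead of 1
```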

With a learning rate of 0.1, accuracy reached 60%.
With a learning rate of 0.05, accuracy reached 65%.
Code: https://github.com/yingtaomj/cnn-classification
