Continued from the previous post: TensorFlow Study Notes: Retrain Inception_v3 (Part 1)
5. Modifying the code
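Reading the source shows that all of the program's hyperparameters are passed in on the command line, and every command-line argument has a default value. To run the script directly, we only need to change the default paths of the command-line arguments in the last part of the code, pointing the model download, dataset and output file locations to our own directories: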
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_dir',
      type=str,
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/flower_photos',
      #default='',
      help='Path to folders of labeled images.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      #default='/tmp/output_graph.pb',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/output_graph.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--intermediate_output_graphs_dir',
      type=str,
      #default='/tmp/intermediate_graph/',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/intermediate_graph/',
      help='Where to save the intermediate graphs.'
  )
  parser.add_argument(
      '--intermediate_store_frequency',
      type=int,
      default=0,
      help="""\
      How many steps to store intermediate graph. If "0" then will not
      store.\
      """
  )
  parser.add_argument(
      '--output_labels',
      type=str,
      #default='/tmp/output_labels.txt',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/output_labels.txt',
      help='Where to save the trained graph\'s labels.'
  )
  parser.add_argument(
      '--summaries_dir',
      type=str,
      #default='/tmp/retrain_logs',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/retrain_logs',
      help='Where to save summary logs for TensorBoard.'
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=int,
      default=4000,
      help='How many training steps to run before ending.'
  )
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='How large a learning rate to use when training.'
  )
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a test set.'
  )
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a validation set.'
  )
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=10,
      help='How often to evaluate the training results.'
  )
  parser.add_argument(
      '--train_batch_size',
      type=int,
      default=100,
      help='How many images to train on at a time.'
  )
  parser.add_argument(
      '--test_batch_size',
      type=int,
      default=-1,
      help="""\
      How many images to test on. This test set is only used once, to evaluate
      the final accuracy of the model after training completes.
      A value of -1 causes the entire test set to be used, which leads to more
      stable results across runs.\
      """
  )
  parser.add_argument(
      '--validation_batch_size',
      type=int,
      default=100,
      help="""\
      How many images to use in an evaluation batch. This validation set is
      used much more often than the test set, and is an early indicator of how
      accurate the model is during training.
      A value of -1 causes the entire validation set to be used, which leads to
      more stable results across training iterations, but may be slower on large
      training sets.\
      """
  )
  parser.add_argument(
      '--print_misclassified_test_images',
      default=False,
      help="""\
      Whether to print out a list of all misclassified test images.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--model_dir',
      type=str,
      #default='/tmp/imagenet',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/inception-2015-12-05',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--bottleneck_dir',
      type=str,
      #default='/tmp/bottleneck',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/bottleneck',
      help='Path to cache bottleneck layer values as files.'
  )
  parser.add_argument(
      '--final_tensor_name',
      type=str,
      default='final_result',
      help="""\
      The name of the output classification layer in the retrained graph.\
      """
  )
  parser.add_argument(
      '--flip_left_right',
      default=False,
      help="""\
      Whether to randomly flip half of the training images horizontally.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--random_crop',
      type=int,
      default=0,
      help="""\
      A percentage determining how much of a margin to randomly crop off the
      training images.\
      """
  )
  parser.add_argument(
      '--random_scale',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly scale up the size of the
      training images by.\
      """
  )
  parser.add_argument(
      '--random_brightness',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly multiply the training image
      input pixels up or down by.\
      """
  )
  parser.add_argument(
      '--architecture',
      type=str,
      default='inception_v3',
      help="""\
      Which model architecture to use. 'inception_v3' is the most accurate, but
      also the slowest. For faster or smaller models, choose a MobileNet with the
      form 'mobilenet_<parameter size>_<input_size>[_quantized]'. For example,
      'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224
      pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much
      less accurate, but smaller and faster network that's 920 KB on disk and
      takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
      for more information on Mobilenet.\
      """)
  # The rest of the __main__ block (parse_known_args() and tf.app.run()) is left unchanged.
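Editing the defaults is one option; the same values can also be passed on the command line when launching retrain.py. The sketch below assumes a cut-down parser (build_parser is a hypothetical helper that holds just two of the arguments above) to show that flags given explicitly override their defaults, while omitted flags keep them:

```python
import argparse


def build_parser():
    # Hypothetical reduced parser: only two of retrain.py's arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, default='',
                        help='Path to folders of labeled images.')
    parser.add_argument('--how_many_training_steps', type=int, default=4000,
                        help='How many training steps to run before ending.')
    return parser


# Passing an explicit argv list mimics typing the flags on the command line
flags = build_parser().parse_args(['--image_dir=flower_photos'])
print(flags.image_dir)                # flower_photos (overridden)
print(flags.how_many_training_steps)  # 4000 (default kept)
```

Either way works; editing defaults is convenient inside an IDE such as Spyder, while command-line flags avoid touching the source.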
6. Running
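The program first loads the Inception model, then feeds the dataset through it to generate bottleneck files, which are stored locally. It then uses these bottleneck files to train the final softmax layer. The training results: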
(Figure: training run on an i7-6600U under Win10)
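This run used Win10, Anaconda and Python 3.5, training on the CPU. It took roughly 10 to 20 minutes (I forgot to add timing to the code). The output of the last few training steps is below; the final test accuracy was 91.8%: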
INFO:tensorflow:2017-10-13 13:30:55.851475: Step 3980: Train accuracy = 97.0%
INFO:tensorflow:2017-10-13 13:30:55.852477: Step 3980: Cross entropy = 0.105862
INFO:tensorflow:2017-10-13 13:30:55.973800: Step 3980: Validation accuracy = 93.0% (N=100)
INFO:tensorflow:2017-10-13 13:30:57.164969: Step 3990: Train accuracy = 99.0%
INFO:tensorflow:2017-10-13 13:30:57.165970: Step 3990: Cross entropy = 0.083922
INFO:tensorflow:2017-10-13 13:30:57.287292: Step 3990: Validation accuracy = 85.0% (N=100)
INFO:tensorflow:2017-10-13 13:30:58.387826: Step 3999: Train accuracy = 94.0%
INFO:tensorflow:2017-10-13 13:30:58.388830: Step 3999: Cross entropy = 0.175846
INFO:tensorflow:2017-10-13 13:30:58.514203: Step 3999: Validation accuracy = 90.0% (N=100)
INFO:tensorflow:Final test accuracy = 91.8% (N=732)
INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
An exception has occurred, use %tb to see the full traceback.
SystemExit
C:\Users\Dexter\Anaconda2\envs\TensorFlow_Py35\lib\site-packages\IPython\core\interactiveshell.py:2870: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
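The bottleneck files mentioned above are what make retraining cheap: each image goes through the expensive Inception forward pass only once, and every later training step just reads the cached vector. A minimal sketch of that caching pattern follows; get_or_create_bottleneck and compute_fn are hypothetical names, and the comma-separated text format is an assumption modeled loosely on retrain.py's cache files, not a copy of its code:

```python
import os


def get_or_create_bottleneck(cache_dir, image_name, compute_fn):
    """Return the bottleneck vector for image_name, computing it only once.

    compute_fn is a hypothetical stand-in for running the image through the
    Inception graph up to the bottleneck layer.
    """
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, image_name + '.txt')
    if not os.path.exists(path):
        # Cache miss: run the expensive forward pass and store the result
        values = compute_fn(image_name)
        with open(path, 'w') as f:
            f.write(','.join(str(x) for x in values))
    # Cache hit (or freshly written file): read the stored values back
    with open(path) as f:
        return [float(x) for x in f.read().split(',')]
```

Calling it twice for the same image runs compute_fn only once; the second call is served from the file, which is why a rerun with an existing bottleneck directory is much faster.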
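Note that after training finishes, the program freezes the two variables of the softmax layer (the weights and biases of its single fully connected layer) and combines them with the remaining Inception model parameters into a new frozen model.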
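Those two frozen variables define a fully connected layer that maps each bottleneck vector (2048 values for Inception v3) to one logit per class; softmax then turns the logits into probabilities. The sketch below illustrates that computation in plain Python with made-up numbers and toy sizes (3 bottleneck values, 2 classes); final_layer is a hypothetical name, not a function in retrain.py:

```python
import math


def final_layer(bottleneck, weights, biases):
    """logits = bottleneck . W + b, followed by softmax.

    weights has one row per bottleneck value and one column per class.
    """
    num_classes = len(biases)
    logits = [
        sum(bottleneck[i] * weights[i][c] for i in range(len(bottleneck))) + biases[c]
        for c in range(num_classes)
    ]
    # Subtract the max logit for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up values: 3-dim bottleneck, 2 classes
probs = final_layer([1.0, 0.0, 2.0],
                    [[0.5, -0.5], [1.0, 1.0], [0.0, 0.25]],
                    [0.1, -0.1])
print(probs)
```

Here the logits are 0.6 and -0.1, so the first class gets about 0.67 of the probability mass.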
Later I reran it on a GTX 1080 Ti under Ubuntu. About 5 minutes was enough, but I was less lucky this time and the accuracy was only 88.0%.
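Neither run above was timed precisely. A minimal way to record wall-clock time is to bracket the work with time.perf_counter(); in the sketch below, dummy_training is a hypothetical placeholder for the training loop, and where exactly to place the calls inside retrain.py is left as an assumption:

```python
import time


def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


def dummy_training(steps):
    # Hypothetical stand-in for the training loop in retrain.py
    total = 0
    for i in range(steps):
        total += i
    return total


result, seconds = timed(dummy_training, 1000)
print('Training took %.2f s' % seconds)
```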
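7. Testing
7.1 Downloading data
I downloaded some images from the web, deliberately picking ones that belong to the five categories but are hard to tell apart: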
(Figures: daisy.jpg and roses.jpg)
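The first looks like a mix of a daisy and a sunflower; for the second, I honestly cannot tell whether it is a rose or a tulip. In total I prepared 20 images: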
(Figure: the test_images folder)
7.2 Test script
The script is mainly adapted from the code in TensorFlow Study Notes: Image Classification with Inception v3:
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 13 16:15:16 2017
use_output_graph
Test the transferred Inception model trained by retrain.py
@author: Dexter
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

model_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3'
model_name = 'output_graph.pb'
image_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/use_output_graph/test_images'
label_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/'
label_filename = 'output_labels.txt'


# Read the retrained model (output_graph.pb) and import it into the default graph
def create_graph():
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        # Create an empty GraphDef and fill it from the serialized file
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')


# Read the labels, one per line
def load_labels(label_file_dir):
    # Check up front that the file exists
    if not tf.gfile.Exists(label_file_dir):
        tf.logging.fatal('File does not exist %s', label_file_dir)
        # tf.logging.fatal only logs, so stop explicitly
        raise FileNotFoundError(label_file_dir)
    # Read all labels and return them as a list without trailing newlines
    labels = tf.gfile.GFile(label_file_dir).readlines()
    return [label.strip('\n') for label in labels]


# Create the graph
create_graph()
# Load the labels once instead of re-reading the file for every prediction
labels = load_labels(os.path.join(label_dir, label_filename))

# Create a session; the frozen graph stores its weights as constants,
# so no variable initialization is needed
with tf.Session() as sess:
    # Output of the retrained model's final layer, final_result:0
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    # Walk the image directory
    for root, dirs, files in os.walk(image_dir):
        for file in files:
            image_path = os.path.join(root, file)
            # Load the raw JPEG bytes
            with tf.gfile.FastGFile(image_path, 'rb') as f:
                image_data = f.read()
            # Feed the JPEG data and get the softmax probabilities,
            # a vector of shape (1, number of classes), i.e. (1, 5) here
            predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
            # Flatten the result to 1-D
            predictions = np.squeeze(predictions)
            # Print the image path and name
            print(image_path)
            # Show the image
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()
            # Sort and take the 5 most probable classes (top-5); this dataset has only 5.
            # argsort() returns the indices that would sort the array in ascending order
            top_5 = predictions.argsort()[-5:][::-1]
            for label_index in top_5:
                # Class name
                label_name = labels[label_index]
                # Confidence for that class
                label_score = predictions[label_index]
                print('%s (score = %.5f)' % (label_name, label_score))
            print()
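The line predictions.argsort()[-5:][::-1] sorts the class indices by ascending probability, keeps the last five and reverses them, so the most likely class comes first. The plain-Python sketch below reproduces that idea without NumPy; top_k_indices is a hypothetical helper, and the label list and scores are made up for illustration:

```python
def top_k_indices(scores, k):
    """Indices of the k largest scores, largest first (like argsort()[-k:][::-1])."""
    # Ascending order of indices by score, like NumPy's argsort()
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    # Keep the last k (largest) and reverse so the best comes first
    return order[-k:][::-1]


labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
scores = [0.01, 0.02, 0.90, 0.03, 0.04]  # made-up probabilities
for i in top_k_indices(scores, 3):
    print('%s (score = %.5f)' % (labels[i], scores[i]))
# roses (score = 0.90000)
# tulips (score = 0.04000)
# sunflowers (score = 0.03000)
```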
7.3 Output
Some of the results:
C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/use_output_graph/test_images\18.jpg
roses (score = 0.99892)
tulips (score = 0.00065)
sunflowers (score = 0.00032)
dandelion (score = 0.00007)
daisy (score = 0.00003)
C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/use_output_graph/test_images\19.jpg
roses (score = 0.99359)
tulips (score = 0.00633)
sunflowers (score = 0.00008)
daisy (score = 0.00000)
dandelion (score = 0.00000)
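Because the final layer is a softmax over the five classes, each image's five scores should add up to 1, up to the rounding to five decimals done by the print statement. A quick check on the two results above (scores copied from the output):

```python
# Scores copied from the output above (rounded to 5 decimals by the script)
results = {
    '18.jpg': [0.99892, 0.00065, 0.00032, 0.00007, 0.00003],
    '19.jpg': [0.99359, 0.00633, 0.00008, 0.00000, 0.00000],
}
for name, scores in results.items():
    print('%s: sum = %.5f' % (name, sum(scores)))
# 18.jpg: sum = 0.99999
# 19.jpg: sum = 1.00000
```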
Of course, not every image was predicted this confidently; a few nearly went wrong:
(Figures: sunflowers, roses)
This distorted daisy, on the other hand, was no problem at all:
(Figure: daisy)
7.4 label_image.py
It turns out Google also provides a test script:
tensorflow/tensorflow/examples/image_retraining/label_image.py
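However, this script classifies only one image per run, so I find it less convenient than my own. To use it, pass --graph, --labels and --image on the command line; these three are declared with required=True, so changing their defaults alone is not enough unless required is also removed. The unmodified source is below, kept for reference: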
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple image classification with Inception.

Run image classification with your model.

This script is usually used with retrain.py found in this same
directory.

This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. You are required
to pass in the graph file and the txt file.

It outputs human readable strings of the top 5 predictions along with
their probabilities.

Change the --image_file argument to any jpg image to compute a
classification of that image.

Example usage:
python label_image.py --graph=retrained_graph.pb
  --labels=retrained_labels.txt
  --image=flower_photos/daisy/54377391_15648e8d18.jpg

NOTE: To learn to use this file and retrain.py, please see:

https://codelabs.developers.google.com/codelabs/tensorflow-for-poets
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument(
    '--image', required=True, type=str, help='Absolute path to image file.')
parser.add_argument(
    '--num_top_predictions',
    type=int,
    default=5,
    help='Display this many predictions.')
parser.add_argument(
    '--graph',
    required=True,
    type=str,
    help='Absolute path to graph file (.pb)')
parser.add_argument(
    '--labels',
    required=True,
    type=str,
    help='Absolute path to labels file (.txt)')
parser.add_argument(
    '--output_layer',
    type=str,
    default='final_result:0',
    help='Name of the result operation')
parser.add_argument(
    '--input_layer',
    type=str,
    default='DecodeJpeg/contents:0',
    help='Name of the input operation')


def load_image(filename):
  """Read in the image_data to be classified."""
  return tf.gfile.FastGFile(filename, 'rb').read()


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.gfile.GFile(filename)]


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.gfile.FastGFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def run_graph(image_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  with tf.Session() as sess:
    # Feed the image_data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: image_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0


def main(argv):
  """Runs inference on an image."""
  if argv[1:]:
    raise ValueError('Unused Command Line Args: %s' % argv[1:])

  if not tf.gfile.Exists(FLAGS.image):
    tf.logging.fatal('image file does not exist %s', FLAGS.image)

  if not tf.gfile.Exists(FLAGS.labels):
    tf.logging.fatal('labels file does not exist %s', FLAGS.labels)

  if not tf.gfile.Exists(FLAGS.graph):
    tf.logging.fatal('graph file does not exist %s', FLAGS.graph)

  # load image
  image_data = load_image(FLAGS.image)

  # load labels
  labels = load_labels(FLAGS.labels)

  # load graph, which is stored in the default session
  load_graph(FLAGS.graph)

  run_graph(image_data, labels, FLAGS.input_layer, FLAGS.output_layer,
            FLAGS.num_top_predictions)


if __name__ == '__main__':
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=sys.argv[:1]+unparsed)
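The last two lines split the command line in two: parse_known_args() returns the flags the parser defines as FLAGS and leaves anything it does not recognize in unparsed, which is handed on together with the program name (sys.argv[:1]). The argparse half of that can be tried without TensorFlow; the parser below is a reduced copy with made-up argument values, and --unknown_flag is deliberately not defined:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--image', required=True, type=str)
parser.add_argument('--num_top_predictions', type=int, default=5)

# Example command line (made-up values); '--unknown_flag=1' is not defined above
argv = ['--image=test.jpg', '--num_top_predictions=3', '--unknown_flag=1']
FLAGS, unparsed = parser.parse_known_args(argv)
print(FLAGS.image, FLAGS.num_top_predictions)  # test.jpg 3
print(unparsed)                                # ['--unknown_flag=1']
```

Unlike parse_args(), which would exit with an error on --unknown_flag, parse_known_args() tolerates it; in label_image.py, main() then raises ValueError for any leftover arguments that reach it.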
8. TensorBoard visualization
Open a command prompt and enter:
tensorboard --logdir /tmp/retrain_logs
# Replace /tmp/retrain_logs with the actual log directory if you changed the --summaries_dir default earlier
However, entering
http://Dexter:6006
in the browser did not open the page. The fix is to enter
localhost:6006
instead.
(Figure: TensorBoard)