1. 去官网下载
image.png
2. 网上下载一些文件,做成像我这样的
image.png
3. 训练模型(用自己的图片重新训练最后一层)
# 用到了retrain.py 文件
python /Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/retrain.py
# 瓶颈层(bottleneck)特征的缓存目录;只重新训练最后一层
--bottleneck_dir bottleneck
# 训练200步
--how_many_training_steps 200
# 用到tensorflow(11) inception 模型
--model_dir /Users/chengkai/Documents/06_code/code/tensorflow/inception_model/
# 输出文件
--output_graph output_graph.pb
--output_labels output_labels.txt
# 用来训练的图片
--image_dir /Users/chengkai/Documents/06_code/code/tensorflow/images/
image.png
4. 跑测试数据
image.png
# coding: utf-8
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
# Read the labels file written by retrain.py (one class name per line) and
# build a lookup keyed by 0-based line index; the prediction loop uses that
# index as the class id when it looks up a label.
with tf.gfile.GFile('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/output_labels.txt') as labels_file:
    lines = labels_file.readlines()
# Only the trailing newline is stripped, so labels keep any inner spaces.
uid_to_human = {uid: line.strip('\n') for uid, line in enumerate(lines)}
def id_to_string(node_id):
    """Return the human-readable label for a class index.

    Args:
        node_id: Class index used as a key into the module-level
            ``uid_to_human`` mapping.

    Returns:
        The label stored for ``node_id``, or an empty string when the index
        has no entry.
    """
    return uid_to_human.get(node_id, '')
# Load the retrained model's serialized GraphDef and import it into the
# default graph. name='' imports the ops without a name prefix, so tensors
# can be fetched by their original names (e.g. 'final_result:0').
# tf.gfile.GFile replaces the deprecated FastGFile alias.
with tf.gfile.GFile('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
def _build_feed_dict(graph, image_path):
    """Return the feed dict that presents one image file to ``graph``.

    Two graph layouts are supported:

    * Graphs that contain a ``DecodeJpeg/contents`` op receive the raw file
      bytes and decode the JPEG themselves.
    * Graphs without that op are fed a decoded image through
      ``Placeholder:0``.

    NOTE(review): the ``Placeholder:0`` name, the 299x299 fallback size and
    the [0, 1] pixel scaling follow the TF-Hub based retrain.py / label_image.py
    defaults -- confirm against the graph that was actually exported.

    Args:
        graph: The ``tf.Graph`` holding the imported model.
        image_path: Path of the image file to classify.

    Returns:
        A dict suitable for the ``feed_dict`` argument of ``Session.run``.
    """
    op_names = {op.name for op in graph.get_operations()}
    if 'DecodeJpeg/contents' in op_names:
        with tf.gfile.GFile(image_path, 'rb') as image_file:
            return {'DecodeJpeg/contents:0': image_file.read()}

    input_tensor = graph.get_tensor_by_name('Placeholder:0')
    height, width = 299, 299
    static_shape = input_tensor.shape
    if static_shape.ndims == 4:
        # Prefer the size recorded in the graph; unknown dims come back as None.
        _, graph_height, graph_width, _ = static_shape.as_list()
        height = graph_height or height
        width = graph_width or width

    with Image.open(image_path) as source_image:
        resized = source_image.convert('RGB').resize((width, height), Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Add the leading batch dimension expected by the placeholder.
    return {input_tensor: pixels[np.newaxis, ...]}


with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    # Walk the test directory and classify every image file in it.
    for root, _dirs, files in os.walk('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/test/'):
        for filename in files:
            # Skip hidden files such as .DS_Store.
            if filename.startswith("."):
                continue
            print(filename)
            image_path = os.path.join(root, filename)

            # Run the model and flatten the result to a 1-D score vector.
            predictions = sess.run(softmax_tensor, _build_feed_dict(sess.graph, image_path))
            predictions = np.squeeze(predictions)

            # Print the image path, then display the image.
            print(image_path)
            with Image.open(image_path) as img:
                plt.imshow(img)
                plt.axis('off')
                plt.show()

            # Class indices ordered from highest to lowest score.
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:
                # Label and confidence for this class.
                human_string = id_to_string(node_id)
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()
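# 这一步我报错了(原因:新版 hub-master 的 retrain.py 导出的图里没有 DecodeJpeg/contents 节点,输入节点是 Placeholder,需要喂入已解码、缩放到模型输入尺寸并归一化到 0~1 的图片数组)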
TypeError: Cannot interpret feed_dict key as Tensor: The name 'DecodeJpeg/contents:0' refers to a Tensor which does not exist. The operation, 'DecodeJpeg/contents', does not exist in the graph.
image.png
网友评论