美文网首页
Tensorflow(12)训练自己的数据

Tensorflow(12)训练自己的数据

作者: Thinkando | 来源:发表于2018-11-27 23:49 被阅读120次

    1. 去官网下载

    image.png

    2. 网上下载一些文件,做成像我这样的

    image.png

    3. 跑测试数据

    # 用到了retrain.py 文件
    python /Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/retrain.py 
    # 只跑最后一层
    --bottleneck_dir bottleneck 
    # 训练200遍
    --how_many_training_steps 200 
    # 用到tensorflow(11) inception 模型
    --model_dir /Users/chengkai/Documents/06_code/code/tensorflow/inception_model/ 
    # 输出文件
    --output_graph output_graph.pb 
    --output_labels output_labels.txt 
    # 用来训练的图片
    --image_dir /Users/chengkai/Documents/06_code/code/tensorflow/images/
    
    image.png

    4. 跑测试数据

    image.png
    
    # coding: utf-8
    
    import tensorflow as tf
    import os
    import numpy as np
    import re
    from PIL import Image
    import matplotlib.pyplot as plt
    
    
    lines = tf.gfile.GFile('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/output_labels.txt').readlines()
    uid_to_human = {}
    #一行一行读取数据
    for uid,line in enumerate(lines) :
        #去掉换行符
        line=line.strip('\n')
        uid_to_human[uid] = line
    
    def id_to_string(node_id):
        if node_id not in uid_to_human:
            return ''
        return uid_to_human[node_id]
    
    #创建一个图来存放google训练好的模型
    with tf.gfile.FastGFile('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/output_graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    
    
    with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
        #遍历目录
        for root,dirs,files in os.walk('/Users/chengkai/Documents/06_code/code/tensorflow/hub-master/examples/image_retraining/test/'):
            for file in files:
                if file.startswith("."):
                    continue
                print(file)
                #载入图片
                image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
                predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
                predictions = np.squeeze(predictions)#把结果转为1维数据
    
                #打印图片路径及名称
                image_path = os.path.join(root,file)
                print(image_path)
                #显示图片
                img=Image.open(image_path)
                plt.imshow(img)
                plt.axis('off')
                plt.show()
    
                #排序
                top_k = predictions.argsort()[::-1]
                print(top_k)
                for node_id in top_k:     
                    #获取分类名称
                    human_string = id_to_string(node_id)
                    #获取该分类的置信度
                    score = predictions[node_id]
                    print('%s (score = %.5f)' % (human_string, score))
                print()
    
    # 这一步我报错了
    TypeError: Cannot interpret feed_dict key as Tensor: The name 'DecodeJpeg/contents:0' refers to a Tensor which does not exist. The operation, 'DecodeJpeg/contents', does not exist in the graph.
    
    image.png

    相关文章

      网友评论

          本文标题:Tensorflow(12)训练自己的数据

          本文链接:https://www.haomeiwen.com/subject/poonqqtx.html