美文网首页我爱编程
【Tensorflow】学习二:图像分类器

【Tensorflow】学习二:图像分类器

作者: 下里巴人也 | 来源:发表于2017-11-08 14:39 被阅读962次

    本机环境
    centos6.8 + python2.7 + tensorflow0.11

    下载tensorflow源码

    从github上拉去代码并切换到0.11版本:

    git clone https://github.com/tensorflow/tensorflow
    git checkout r0.11

    google-Inception模型示例

    执行如下命令,利用google的inception模型识别图片space_shuttle.jpg

    cd tensorflow/models/images/imagenet/
    python classify_image.py --image_file /home/xiabing/TensorFlow_pics/space_shuttle.jpg

    可以看到识别结果如下:


    space_shuttle_result.jpg

    分析classify_image.py

    下面看看classify_image.py的源码
    classify_image.py会首先下载分类器模型:
    DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    下载后会放到本地/tmp/imagenet/路径下:

    inception.jpg

    训练自己的分类模型

    使用tensorflow中examples中的image_retraining来retraining谷歌的inception模型

    准备图片数据

    准备要训练的每个分类,需要有个对应的文件夹(因为每个子文件夹内的各个图片的label标签就是取分类文件夹名的)类似以下这种:

    fruit/banna/
    fruit/apple/

    每个分类内的数据格式没有规定,本例如下:


    每个分类下图片.jpg

    使用retraining.py训练

    调用如下命令开始训练,参数详解参见retrain.py文件:

    python /home/xiabing/TensorFlow/tensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dir /home/xiabing/sd_classify_pics/bottleneck --how_many_training_steps 4000 --model_dir /home/xiabing/sd_classify_pics/model --output_graph /home/xiabing/sd_classify_pics/output_graph.pb --output_labels /home/xiabing/sd_classify_pics/output_labels.txt --image_dir /home/xiabing/TensorFlow_pics/fruit/

    首次调用会出现如下错误:

    ImportError: cannot import name graph_util

    解决办法:

    修改retrain.py,把
    from tensorflow.python.framework import graph_util
    替换为
    from tensorflow.python.client import graph_util

    再重新执行上面命令,看到如下打印表示训练完成:


    训练完成.jpg

    训练结果

    训练完成后,会在当前目录下生成下面两个文件。查看标签文件,会看到banana和apple。


    训练产生结果文件.jpg
    labels内容.jpg

    使用训练好的模型

    在训练结果路径下新建test.py文件,加入如下代码:

      import tensorflow as tf
      import sys
    
      image_file = sys.argv[1]
      #print(image_file)
    
      image = tf.gfile.FastGFile(image_file, 'rb').read()
    
      labels = []
      for label in tf.gfile.GFile("output_labels.txt"):
          labels.append(label.rstrip())
    
      with tf.gfile.FastGFile("output_graph.pb", 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')
    
      with tf.Session() as sess:
      softmax_tensor =     sess.graph.get_tensor_by_name('final_result:0')
      predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})
    
      top = predict[0].argsort()[-len(predict[0]):][::-1]
      for index in top:
            human_string = labels[index]
            score = predict[0][index]
            print(human_string, score)         
    

    测试训练好的模型:

    python /home/xiabing/sd_classify_pics/test.py /home/xiabing/TensorFlow_pics/1510114397170.jpg

    原始图片:


    1510114397170.jpg

    测试结果:


    结果.jpg

    相关文章

      网友评论

        本文标题:【Tensorflow】学习二:图像分类器

        本文链接:https://www.haomeiwen.com/subject/epqemxtx.html