美文网首页Tensorflow
使用单个模型文件进行预测

使用单个模型文件进行预测

作者: 拉赫曼 | 来源:发表于2018-06-29 15:09 被阅读0次

    先上代码

    import tensorflow as tf
    import numpy as np
    from tensorflow.python.platform import gfile
    from tensorflow.python.lib.io import file_io
    
    input_tensor_key = 'Placeholder:0'
    
    def loadNpData(filename):
        tensor_key_feed_dict = {}
    
        #inputs = preprocess_inputs_arg_string(inputs_str)
        data = np.load(file_io.FileIO(filename, mode='r'))
    
        # When no key is specified for the input file.
        # Check if npz file only contains a single numpy ndarray.
        if isinstance(data, np.lib.npyio.NpzFile):
            variable_name_list = data.files
            if len(variable_name_list) != 1:
                raise RuntimeError(
                    'Input file %s contains more than one ndarrays. Please specify '
                    'the name of ndarray to use.' % filename)
            tensor_key_feed_dict[input_tensor_key] = data[variable_name_list[0]]
        else:
            tensor_key_feed_dict[input_tensor_key] = data
        return tensor_key_feed_dict
    
    with tf.Session() as sess:
        # 定义模型文件及样本测试文件
        model_filename = 'merge1_graph.pb'
        example_png = 'examples.npy'
        # 加载npy格式的图片测试样本数据
        image_data = loadNpData(example_png)
        #加载模型文件
        with gfile.FastGFile(model_filename, 'rb') as f:
            graph_def = tf.GraphDef();
            graph_def.ParseFromString(f.read())
    
        # 获取输入节点的tensor
        inputs = sess.graph.get_tensor_by_name("Placeholder:0");
        #打印输入节点的信息
        #print inputs
        # 导入计算图,定义输入节点及输出节点
        output = tf.import_graph_def(graph_def, input_map={'Placeholder:0':inputs}, return_elements=[ 'ArgMax:0','Softmax:0']) 
        # 打印输出节点的信息
        #print output
        results = sess.run(output, feed_dict={inputs:image_data[input_tensor_key]})
        print 'ArgMax result(预测结果对应的标签值):'  
        print results[0]
        print 'Softmax result(最后一层的输出):'
        print results[1]
        # 输出node详细信息,此处默认只打印第一个节点
        for node in graph_def.node:
            print node
            break
    

    运行输出

    ArgMax result(预测结果对应的标签值):
    [3 3]
    Softmax result(最后一层的输出):
    [[4.1668140e-12 9.0696268e-18 6.4261091e-13 9.9999940e-01 1.7161388e-30
      5.4321697e-07 7.6357951e-09 6.3293229e-19 1.3812791e-13 1.5360580e-12]
     [1.1472046e-05 3.3404951e-10 6.0365837e-09 9.9997592e-01 9.8635665e-15
      5.7557719e-07 1.1977763e-05 1.6275100e-16 7.2288098e-10 5.0601763e-08]]
    
    

    此处加载的关键在于tf.import_graph_def函数的参数配置,三个参数graph_def input_map return_elements

    第一个参数是导入的图
    input_map是指定输入节点,如果不指定,后面run的时候会报错 ==You must feed a value for placeholder tensor 'Placeholder'==

    return_elements 是指定运算后的输出节点,此处就是我们想要得到的标签估计值 ArgMax 以及 最后一层节点输出 Softmax

    模型的测试参考 将Tensorflow模型导出为单个文件

    相关文章

      网友评论

        本文标题:使用单个模型文件进行预测

        本文链接:https://www.haomeiwen.com/subject/rzsayftx.html