TensorFlow Study Notes (14): Using SignatureDef


Author: 谢昆明 | Published: 2018-11-20 07:49

    Environment:
    Python 3.5.2
    tensorflow : 1.11.0
    ubuntu : 16.04

    Saving the model (full code on GitHub)

      import tensorflow as tf

      # xs, ys, keep_prob, prediction, init, train_step, mnist and compute_accuracy
      # are defined in the rest of the script (not shown here).
      saved_model_dir = './model'
      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

      # Inputs of the signature: input_x, input_y, keep_prob
      inputs = {'input_x': tf.saved_model.utils.build_tensor_info(xs),
                'input_y': tf.saved_model.utils.build_tensor_info(ys),
                'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}

      # prediction is the model's output tensor; inference on the restored model runs through it
      outputs = {'prediction': tf.saved_model.utils.build_tensor_info(prediction)}

      # The third argument is the signature's method_name, not the key used to look it up
      signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')

      with tf.Session() as sess:
          sess.run(init)

          for i in range(1000):
              batch_xs, batch_ys = mnist.train.next_batch(100)
              sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
              if i % 50 == 0:
                  print(compute_accuracy(sess, prediction,
                                         mnist.test.images[:1000], mnist.test.labels[:1000]))

          # 'model_final' is the tag and 'test_signature' is the signature key; both are needed when restoring
          builder.add_meta_graph_and_variables(sess, ['model_final'], {'test_signature': signature})
          builder.save()
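
One practical point when rerunning the export: as far as I know, `SavedModelBuilder` raises an error if the export directory already exists, so a second run against `./model` fails. Below is a minimal sketch of clearing the directory first, using only the standard library. The helper name `reset_export_dir` is hypothetical and not part of the original code.

```python
import os
import shutil


def reset_export_dir(export_dir):
    """Delete export_dir (and everything in it) if it exists.

    Hypothetical helper: call it before constructing SavedModelBuilder so
    that the builder can create the directory itself.
    """
    if os.path.isdir(export_dir):
        shutil.rmtree(export_dir)
    return export_dir
```

It could be used as `builder = tf.saved_model.builder.SavedModelBuilder(reset_export_dir('./model'))`. Note that this deletes any previously exported model, so it only suits throwaway experiments.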
    

    Restoring the model (full code on GitHub)

      import tensorflow as tf

      saved_model_dir = './model'

      signature_key = 'test_signature'
      input_key_x = 'input_x'
      input_key_y = 'input_y'
      input_key_keep_prob = 'keep_prob'
      output_key_prediction = 'prediction'

      with tf.Session() as sess:
          # The tag list must match the one used when the model was saved
          meta_graph_def = tf.saved_model.loader.load(sess, ['model_final'], saved_model_dir)

          # Get the map of SignatureDef objects from meta_graph_def
          signature = meta_graph_def.signature_def

          # Look up the tensor names of the inputs and outputs in the signature
          x_tensor_name = signature[signature_key].inputs[input_key_x].name
          y_tensor_name = signature[signature_key].inputs[input_key_y].name
          keep_prob_tensor_name = signature[signature_key].inputs[input_key_keep_prob].name
          prediction_tensor_name = signature[signature_key].outputs[output_key_prediction].name

          # Fetch the tensors by name; they are used for inference in the next step
          input_x = sess.graph.get_tensor_by_name(x_tensor_name)
          input_y = sess.graph.get_tensor_by_name(y_tensor_name)
          keep_prob = sess.graph.get_tensor_by_name(keep_prob_tensor_name)
          prediction = sess.graph.get_tensor_by_name(prediction_tensor_name)
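
The four lookups above all follow the same pattern: take a key, read the tensor name from the signature, then fetch the tensor from the graph. A minimal sketch of a helper that does this for every input and output of a signature follows. The name `tensors_from_signature` is hypothetical; it relies only on the `signature_def`, `inputs`, `outputs`, `name` and `get_tensor_by_name` accessors already used in the code above.

```python
def tensors_from_signature(graph, meta_graph_def, signature_key):
    """Resolve every input and output of one signature to a graph tensor.

    Returns two dicts keyed by the names chosen when the signature was
    built (for example 'input_x' or 'prediction').
    """
    signature = meta_graph_def.signature_def[signature_key]
    inputs = {key: graph.get_tensor_by_name(info.name)
              for key, info in signature.inputs.items()}
    outputs = {key: graph.get_tensor_by_name(info.name)
               for key, info in signature.outputs.items()}
    return inputs, outputs
```

Inside the `with tf.Session() as sess:` block it could be called as `inputs, outputs = tensors_from_signature(sess.graph, meta_graph_def, 'test_signature')`, after which `inputs['input_x']` and `outputs['prediction']` play the roles of `input_x` and `prediction`.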
    

    Predicting with the restored model

                                                                                                                                                                                                                                                                    
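          # Continues inside the `with tf.Session() as sess:` block above; mnist, index and
          # compute_accuracy come from the rest of the script (not shown here).

          # Test a single sample
          x = mnist.test.images[index].reshape(1, 784)
          y = mnist.test.labels[index].reshape(1, 10)  # one-hot label as a batch of one
          print(y)

          # Dropout is disabled at inference time, so keep_prob is 1.0
          pred_y = sess.run(prediction, feed_dict={input_x: x, keep_prob: 1.0})
          print(pred_y)

          print("Actual class: ", str(sess.run(tf.argmax(y, 1))),
                ", predicted class: ", str(sess.run(tf.argmax(pred_y, 1))),
                ", correct: ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(pred_y, 1)))))

          # Test on the first 1000 samples of the test set
          print(compute_accuracy(sess, prediction, input_x, keep_prob,
                                 mnist.test.images[:1000], mnist.test.labels[:1000]))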
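
The `compute_accuracy` helper called above is not shown in this post. Below is a minimal sketch of what a function with that call signature might look like, assuming `prediction` yields one row of class scores per sample and the labels are one-hot encoded. It is an illustration, not the author's implementation.

```python
import numpy as np


def compute_accuracy(sess, prediction, input_x, keep_prob, images, labels):
    """Fraction of samples whose predicted class matches the one-hot label."""
    # Run the restored graph with dropout disabled (keep_prob = 1.0).
    pred = sess.run(prediction, feed_dict={input_x: images, keep_prob: 1.0})
    # Compare the arg-max class of each prediction with that of each label.
    correct = np.argmax(pred, axis=1) == np.argmax(labels, axis=1)
    return float(np.mean(correct))
```

Doing the comparison in NumPy keeps the loaded graph unchanged, whereas calling `tf.argmax` or `tf.equal` after loading adds new ops to the graph on every call.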
    
