美文网首页
TensorFlow Faster R-CNN ckpt模型转p

TensorFlow Faster R-CNN ckpt模型转p

作者: Andrew_jidw | 来源:发表于2020-11-09 16:35 被阅读0次
       tensorflow版本的faster r-cnn模型保存有两种类型:ckpt模型和pb模型,ckpt模型适合训练阶段,pb模型适合正式生产环境。
    

    1、输入和输出tensor的确定

    输入和输出tensor的确定是ckpt模型转pb模型的关键,因为faster r-cnn模型比较大,尝试过使用tensorboard将模型结果可视化出来,但是还是很难找到输出tensor,后来只好通过faster r-cnn测试脚本demo.py来逐句的看源代码了。最终确定了输入和输出tensor。

        # 定义输出的张量名称
        input_image_tensor = sess.graph.get_tensor_by_name("Placeholder:0")
        tensor_info = sess.graph.get_tensor_by_name("Placeholder_1:0")
    
        biasadd = sess.graph.get_tensor_by_name("vgg_16_3/cls_score/BiasAdd:0")
        score = sess.graph.get_tensor_by_name("vgg_16_3/cls_prob:0")
        bbox = sess.graph.get_tensor_by_name("add:0")
        rois = sess.graph.get_tensor_by_name("vgg_16_1/rois/concat:0")
    
    

    下面简单介绍一下过程。
    调用顺序tools/demo.py->lib/model/test.py->lib/net/network.py
    下面我们逐个看这几个脚本
    1.1、tools/demo.py
    程序开始加载模型、获取图片,然后调用demo()函数开始做识别(我针对自己的应用场景对源代码做了些修改,大体没变)

    def demo(sess, net, image_name,out_file):
        """Detect object classes in an image using pre-computed object proposals."""
        im = cv2.imdecode(np.fromfile(image_name,dtype=np.uint8),1)
    
        # Detect all object classes and regress object bounds
        timer = Timer()
        timer.tic()
        scores, boxes = im_detect(sess, net, im)
        timer.toc()
        print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))
    
        # Visualize detections for each class
        CONF_THRESH = 0.8
        NMS_THRESH = 0.3
    
        im = im[:, :, (2, 1, 0)]
        for cls_ind, cls in enumerate(CLASSES[1:]):
            cls_ind += 1 # because we skipped background
            cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,cls_scores[:, np.newaxis])).astype(np.float32)
            keep = nms(dets, NMS_THRESH)
            dets = dets[keep, :]
            vis_detections(im,cls, dets,out_file, thresh=CONF_THRESH)
        cv2.imencode('.jpg',im)[1].tofile(out_file)
    
    

    可以看到这个函数里关键部分就是下面这句,调用了lib/model/test.py里的im_detect()方法,返回的是分数和检测框。

    scores, boxes = im_detect(sess, net, im)
    
    

    1.2、lib/model/test.py

    def im_detect(sess, net, im):
      blobs, im_scales = _get_blobs(im)
      assert len(im_scales) == 1, "Only single-image batch implemented"
    
      im_blob = blobs['data']
      blobs['im_info'] = np.array([im_blob.shape[1], im_blob.shape[2], im_scales[0]], dtype=np.float32)
    
      _, scores, bbox_pred, rois = net.test_image(sess, blobs['data'], blobs['im_info'])
    
      boxes = rois[:, 1:5] / im_scales[0]
      scores = np.reshape(scores, [scores.shape[0], -1])
      bbox_pred = np.reshape(bbox_pred, [bbox_pred.shape[0], -1])
    
      if cfg.TEST.BBOX_REG:
        # Apply bounding-box regression deltas
        box_deltas = bbox_pred
        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        pred_boxes = _clip_boxes(pred_boxes, im.shape)
      else:
        # Simply repeat the boxes, once for each class
        pred_boxes = np.tile(boxes, (1, scores.shape[1]))
    
      return scores, pred_boxes
    
    

    im_detect中关键的是下面这一句,通过调用lib/network.py的test_image()方法得到检测得分、检测边框和边框修改正值。

      _, scores, bbox_pred, rois = net.test_image(sess, blobs['data'], blobs['im_info'])
    
    

    1.3、lib/net/network.py

      # only useful during testing mode
      def test_image(self, sess, image, im_info):
    
        feed_dict = {self._image: image,
                     self._im_info: im_info}
    
        cls_score, cls_prob, bbox_pred, rois = sess.run([self._predictions["cls_score"],
                                                         self._predictions['cls_prob'],
                                                         self._predictions['bbox_pred'],
                                                         self._predictions['rois']],
                                                        feed_dict=feed_dict)
    
        return cls_score, cls_prob, bbox_pred, rois
    
    

    现在我们要看的就是sess.run这句话,这句是tensorflow开始run模型,输入的参数就是我们想要的,根据run()的调用方式,前面的列表是输出的tensor,后面的feed_dict字典是喂给图的数据和输入的tensor,这里我们可以提取打印一下这几个参数。

      def test_image(self, sess, image, im_info):
    
        feed_dict = {self._image: image,
                     self._im_info: im_info}
    
        print("输入tensor")
        print(self._image)
        print(self._im_info)
        print("输出tensor")
        print(self._predictions["cls_score"])
        print(self._predictions["cls_prob"])
        print(self._predictions["bbox_pred"])
        print(self._predictions["rois"])
        cls_score, cls_prob, bbox_pred, rois = sess.run([self._predictions["cls_score"],
                                                         self._predictions['cls_prob'],
                                                         self._predictions['bbox_pred'],
                                                         self._predictions['rois']],
                                                        feed_dict=feed_dict)
    
        return cls_score, cls_prob, bbox_pred, rois
    
    

    然后我们来执行一下demo.py做一次目标识别。可以看到下面打印出了tensor

    image

    至此我们得到了需要的输入和输出的tensor名称。

    2、ckpt模型转pb模型

    关于如何将ckpt转pb模型,网上已经有很多介绍了, TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,但是这里面有个问题是,faster r-cnn模型在训练阶段输入的tensor有三个,在预测阶段只需要两个,原因是训练阶段需要输入label作为ground truth。测试发现如果使用通用的方式读取ckpt模型然后固化为pb模型需要有3个输入tensor,无法调用,后来在这里找到解决方法https://github.com/endernewton/tf-faster-rcnn/issues/340,他是用demo.py里加载的模型来进行固化的,固化后发现可以正常调用,因此faster r-cnn的固化方法是借用demo.py加载模型的方式,再固化。

    if __name__ == '__main__':
        cfg.TEST.HAS_RPN = True  # Use RPN for proposals
        #args = parse_args()
    
        f_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    
        # model path
        # demonet = args.demo_net
        # dataset = args.dataset
        # input_dir = args.input
        # output_dir = args.output
        demonet = "vgg16"
        dataset = "pascal_voc"
        # input_dir = sys.argv[1]
        # output_dir = sys.argv[2]
        input_dir = r"D:\result\ship_tf\input_data"
        output_dir = r"D:\result\ship_tf"
        tfmodel = os.path.join(f_path,'output', demonet, DATASETS[dataset][0], 'default_ship',
                                  NETS[demonet][0])
    
        if not os.path.isfile(tfmodel + '.meta'):
            raise IOError(('{:s} not found.\nDid you download the proper networks from '
                           'our server and place them properly?').format(tfmodel + '.meta'))
    
        # set config
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        print(tfconfig)
        tfconfig.gpu_options.allow_growth=True
    
        # init session
        sess = tf.Session(config=tfconfig)
        # load network
        if demonet == 'vgg16':
            net = vgg16()
        elif demonet == 'res101':
            net = resnetv1(num_layers=101)
        else:
            raise NotImplementedError
        net.create_architecture("TEST", len(CLASSES),
                              tag='default', anchor_scales=[8, 16, 32])
        saver = tf.train.Saver()
        saver.restore(sess, tfmodel)
    
        #ckpt to pb
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()
        output_graph = r"D:\result\ship_tf\ship_model_from_demo3.pb"
        output_node_names = "vgg_16_3/cls_prob,add,vgg_16_1/rois/concat,vgg_16_3/cls_score/BiasAdd"
        output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(","))
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    

    3、pb模型测试

    通过前两步我们确定了模型输入和输出tensor、得到了pb模型,下面就是直接调用pb模型进行目标识别了。(详细代码见demo_pb.py)
    3.1、输入数据的准备
    先看lib/model/test.py中的im_detect()函数中的下面几行。

      blobs, im_scales = _get_blobs(im)
      assert len(im_scales) == 1, "Only single-image batch implemented"
    
      im_blob = blobs['data']
      blobs['im_info'] = np.array([im_blob.shape[1], im_blob.shape[2], im_scales[0]], dtype=np.float32)
    
    

    调用_get_blobs()方法,传入图片路径,得到blobs和im_scales,blobs字典中的data字段是存储重采样后的图片矩阵,然后将重采样后的图片宽和高、重采样比例三个参数存储成info字段。这个blobs[“info”]就是faster r-cnn模型输入的info信息。
    3.2、模型预测
    这一步主要是加载pb模型和定义输入输出tensor
    模型加载

        p_path=""       #pb模型路径
        with tf.Graph().as_default():
            output_graph_def = tf.GraphDef()
            with open(pb_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name="")
    
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
    
    

    确定输入和输出tensor,(下面代码中的im_info是我自己测试的时候写死的一个info信息。)

        # 定义输出的张量名称
        input_image_tensor = sess.graph.get_tensor_by_name("Placeholder:0")
        tensor_info = sess.graph.get_tensor_by_name("Placeholder_1:0")
    
        biasadd = sess.graph.get_tensor_by_name("vgg_16_3/cls_score/BiasAdd:0")
        score = sess.graph.get_tensor_by_name("vgg_16_3/cls_prob:0")
        bbox = sess.graph.get_tensor_by_name("add:0")
        rois = sess.graph.get_tensor_by_name("vgg_16_1/rois/concat:0")
    
        im_info = np.array([600.0, 600.0, 2.34375], dtype=np.float32)
        _, scores, bbox_pred, rois = sess.run([biasadd, score, bbox, rois],
                                              feed_dict={input_image_tensor: image, tensor_info: im_info})
    
    

    3.3、后处理
    模型预测得到4个返回值,其中有用的是scores、bbox_pred、rois,分别是检测框分数,修正值,检测框。我们需要对初步得到的监测框做修正。主要代码为
    lib/model/test.py中的im_detect()函数

      boxes = rois[:, 1:5] / im_scales[0]
      scores = np.reshape(scores, [scores.shape[0], -1])
      bbox_pred = np.reshape(bbox_pred, [bbox_pred.shape[0], -1])
    
      if cfg.TEST.BBOX_REG:
        # Apply bounding-box regression deltas
        box_deltas = bbox_pred
        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        pred_boxes = _clip_boxes(pred_boxes, im.shape)
      else:
        # Simply repeat the boxes, once for each class
        pred_boxes = np.tile(boxes, (1, scores.shape[1]))
    
    

    关键代码是bbox_transform_inv()和_clip_boxes,分别用来修正检测框和裁切超过图片边界的检测框。
    通过以上分析我们发现可以直接调用源码里的预处理和后处理阶段的部分。

    几个关键点

    1、输入输出tensor的确定
    2、读取ckpt模型需要使用demo.py

    补充

    其实做这个工作是为了把深度学习模型集成到我们的项目里,因此需要用C++来调用faster r-cnn模型。下一篇博客会介绍如何用C++来集成深度学习模型进行目标识别。

    转自:Faster R-CNN ckpt模型转pb模型及其调用-python版本

    相关文章

      网友评论

          本文标题:TensorFlow Faster R-CNN ckpt模型转p

          本文链接:https://www.haomeiwen.com/subject/cyzlbktx.html