ROI Net

作者: 翻开日记 | 来源:发表于2018-08-08 11:37 被阅读0次
    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # @Time    : 18-8-7 下午5:47
    # @File    : fast_roi.py
    # @Software: PyCharm
    # @Author  : wxw
    # @Contact : xwwei@lighten.ai
    # @Desc    : ROI (region-of-interest) cropping net built with TensorFlow
    
    import tensorflow as tf
    from collections import OrderedDict
    from utils import config
    import numpy as np
    
    
    class RoiNets:
        def __init__(self, scores, endpoints, height, width, net_name, is_training=True):
            """Store the inputs and build the ROI tensors derived from them.

            Args:
                scores: score tensor handed to ``get_position``.
                endpoints: mapping from layer name to feature-map tensor; it
                    is indexed with the entries of ``net_name``.
                height: crop heights, indexed by layer position.
                width: crop widths, indexed by layer position.
                net_name: ordered layer names; ``get_roi_maps`` crops every
                    layer except the last one.
                is_training: flag stored on the instance and copied into the
                    ``batch_norm`` settings dict.
            """
            self.scores, self.endpoints = scores, endpoints
            self.height = height
            # NOTE: the attribute name is misspelled ("witdh"); it is kept
            # because get_roi_maps reads it under exactly this name.
            self.witdh = width
            self.net_name = net_name
            self.is_training = is_training
            self.batch_norm = dict(
                is_training=is_training,
                center=True,
                scale=True,
                decay=0.9,
                epsilon=0.001,
            )

            # Build order matters: the crops need both the boxes and the
            # per-box batch indices to exist first.
            self.positions = self.get_position(scores)
            batch_indices = np.arange(config.batch_size)
            self.box_ind = tf.constant(batch_indices, dtype=tf.int32)
            self.cut_maps = self.get_roi_maps()
            self.page_info = self.get_page_info()
    
        def get_position(self, scores, map_size=224):
            """Locate the peak score per sample and return a crop box around it.

            The scores are soft-maxed over the last axis, the second half of
            the channels (axis 3) is kept, and the flattened arg-max of that
            map is converted into a normalised (row, col) position.

            Args:
                scores: 4-D score tensor whose axis 3 can be split in two.
                    NOTE(review): presumably [batch, map_size, map_size, 2]
                    -- confirm against the caller.
                map_size: side length used to turn the flat arg-max index
                    into a row/column and to normalise both. Defaults to 224,
                    the value that was previously hard-coded. The index
                    arithmetic assumes a square map of this side.

            Returns:
                A float32 tensor of shape [config.batch_size, 4] laid out as
                ``[y_min, x_min, y_max, x_max]``. The box extends 0.25 above
                the peak row and 0.25 to either side of the peak column
                (clamped to [0, 1]); ``y_max`` is the peak row itself.
            """
            probs = tf.nn.softmax(scores)
            selected = tf.split(probs, num_or_size_splits=2, axis=3)[1]
            flat_probs = tf.reshape(selected, [config.batch_size, -1])
            peak_idx = tf.expand_dims(tf.argmax(flat_probs, axis=1), 1)

            # Cast to float *before* dividing: `/` on an integer tensor is
            # integer division under Python 2 operator semantics, which would
            # collapse every normalised coordinate to 0.
            side = float(map_size)
            row = tf.cast(peak_idx // map_size, tf.float32) / side
            col = tf.cast(peak_idx % map_size, tf.float32) / side

            y_min = tf.maximum(0.0, row - 0.25)
            x_min = tf.maximum(0.0, col - 0.25)
            x_max = tf.minimum(1.0, col + 0.25)
            # The lower edge is the peak row itself (no +0.25 margin), exactly
            # as in the original implementation.
            y_max = row
            return tf.concat([y_min, x_min, y_max, x_max], axis=1)
    
        def get_roi_maps(self):
            """Crop each listed feature map, except the last, to the ROI boxes.

            For every layer name in ``self.net_name`` but the final one, the
            matching tensor from ``self.endpoints`` is cropped to
            ``self.positions`` and resized to that layer's
            ``[height, width]``. The final layer is stored unchanged. The
            size of axis 3 of every layer is recorded in ``self.chanels``,
            and each resulting tensor is printed.

            Returns:
                An ``OrderedDict`` mapping ``"cut_map_<i>"`` to the tensor for
                layer ``i``, in layer order.
            """
            last = len(self.net_name) - 1
            rois = OrderedDict()
            self.chanels = []

            for level in range(last):
                feature = self.endpoints[self.net_name[level]]
                self.chanels.append(feature.get_shape()[3].value)
                target_size = [self.height[level], self.witdh[level]]
                rois["cut_map_%d" % level] = tf.image.crop_and_resize(
                    image=feature,
                    boxes=self.positions,
                    box_ind=self.box_ind,
                    crop_size=target_size)

            # The final layer is passed through without cropping.
            feature = self.endpoints[self.net_name[last]]
            rois["cut_map_%d" % last] = feature
            self.chanels.append(feature.get_shape()[3].value)

            for idx, tensor in enumerate(rois.values()):
                print('[%d]:' % idx, tensor)
            return rois
    

    相关文章

      网友评论

          本文标题:ROI Net

          本文链接:https://www.haomeiwen.com/subject/mwvtbftx.html