ROI Net

作者: 翻开日记 | 发表于 2018-08-08 11:37 | 被阅读 0 次
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 18-8-7 17:47
# @File    : fast_roi.py
# @Software: PyCharm
# @Author  : wxw
# @Contact : xwwei@lighten.ai
# @Desc    : TensorFlow ROI net: crops backbone feature maps around the peak of a score map

import tensorflow as tf
from collections import OrderedDict
from utils import config
import numpy as np


class RoiNets:
    """Crop backbone feature maps around the peak of a two-channel score map.

    Construction runs the whole pipeline eagerly (graph-building only):

    1. ``get_position`` finds, per batch element, the spatial location with the
       highest softmax probability of channel 1 of ``scores`` and turns it into
       a normalized ``[y1, x1, y2, x2]`` box.
    2. ``get_roi_maps`` crops every endpoint named in ``net_name`` except the
       last one to that box with ``tf.image.crop_and_resize``; the last
       endpoint is stored uncropped.
    3. ``self.get_page_info()`` is called and its result stored.

    Args:
        scores: logits with two channels on axis 3; presumably
            ``[batch, H, W, 2]`` -- TODO confirm against the caller.
        endpoints: mapping from endpoint name to a rank-4 feature tensor with
            channels on axis 3.
        height: sequence of crop heights, one per cropped endpoint.
        width: sequence of crop widths, one per cropped endpoint.
        net_name: sequence of endpoint names; all but the last are cropped.
        is_training: stored on the instance and copied into ``self.batch_norm``.
    """

    def __init__(self, scores, endpoints, height, width, net_name, is_training=True):
        self.is_training = is_training
        self.scores = scores
        self.endpoints = endpoints
        self.width = width
        # Legacy misspelled alias, kept so existing readers of `witdh` still work.
        self.witdh = width
        self.height = height
        self.net_name = net_name
        # Batch-norm settings; not used by the code shown here -- presumably
        # consumed by layers built elsewhere (TODO confirm).
        self.batch_norm = {
            "is_training": is_training,
            "center": True,
            "scale": True,
            "decay": 0.9,
            "epsilon": 0.001,
        }
        self.positions = self.get_position(self.scores)
        # Box i is cropped from batch element i (exactly one box per image).
        self.box_ind = tf.constant(np.arange(config.batch_size), tf.int32)
        self.cut_maps = self.get_roi_maps()
        # NOTE(review): get_page_info is not defined in this class as shown;
        # presumably provided by a subclass or omitted from this listing.
        self.page_info = self.get_page_info()

    def get_position(self, scores):
        """Return one normalized crop box per batch element.

        The box is anchored at the argmax of the channel-1 softmax probability
        map. It spans 0.25 (normalized units) above the peak row and +/-0.25
        around the peak column, clipped to [0, 1]. The bottom edge is the peak
        row itself, so the crop covers the region *above* the peak --
        presumably intentional; TODO confirm.

        Args:
            scores: logits with two channels on axis 3.

        Returns:
            float32 tensor of shape ``[config.batch_size, 4]`` holding
            ``[y1, x1, y2, x2]`` in the normalized coordinates expected by
            ``tf.image.crop_and_resize``.
        """
        probs = tf.nn.softmax(scores)
        # Keep only the channel-1 probability map.
        fg_prob = tf.split(probs, num_or_size_splits=2, axis=3)[1]
        # The argmax runs over the row-major flattened H*W map, so the
        # row/column must be decoded with the map's real width (previously a
        # hard-coded 224) and normalized by its real height/width.
        map_shape = tf.shape(fg_prob, out_type=tf.int64)
        map_h = map_shape[1]
        map_w = map_shape[2]
        flat = tf.reshape(fg_prob, [config.batch_size, -1])
        max_idx = tf.argmax(flat, axis=1)
        row = tf.expand_dims(max_idx // map_w, 1)
        col = tf.expand_dims(max_idx % map_w, 1)
        peak_y = tf.cast(row, tf.float32) / tf.cast(map_h, tf.float32)
        peak_x = tf.cast(col, tf.float32) / tf.cast(map_w, tf.float32)
        y1 = tf.maximum(0.0, peak_y - 0.25)
        x1 = tf.maximum(0.0, peak_x - 0.25)
        y2 = peak_y
        x2 = tf.minimum(1.0, peak_x + 0.25)
        return tf.concat([y1, x1, y2, x2], axis=1)

    def get_roi_maps(self):
        """Crop the selected endpoints to ``self.positions``.

        Every endpoint named in ``self.net_name`` except the last is resized
        to ``[self.height[i], self.width[i]]`` from the box of the matching
        batch element; the last endpoint is passed through uncropped. The
        axis-3 size (channel count) of each map is collected in
        ``self.chanels`` (attribute name kept for compatibility). Each map is
        printed for debugging.

        Returns:
            OrderedDict mapping ``"cut_map_<i>"`` to the i-th map, in
            ``net_name`` order.
        """
        cut_maps = OrderedDict()
        self.chanels = []
        last = len(self.net_name) - 1
        for i in range(last):
            net = self.endpoints[self.net_name[i]]
            self.chanels.append(net.get_shape()[3].value)
            cut_maps["cut_map_%d" % i] = tf.image.crop_and_resize(
                image=net,
                boxes=self.positions,
                box_ind=self.box_ind,
                crop_size=[self.height[i], self.width[i]])
        # The last endpoint is kept at its native resolution.
        net = self.endpoints[self.net_name[last]]
        cut_maps["cut_map_%d" % last] = net
        self.chanels.append(net.get_shape()[3].value)

        for idx, amap in enumerate(cut_maps.values()):
            print('[%d]:' % idx, amap)
        return cut_maps

相关文章

网友评论

      本文标题:ROI Net

      本文链接:https://www.haomeiwen.com/subject/mwvtbftx.html