#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 18-8-7 下午5:47
# @File : fast_roi.py
# @Software: PyCharm
# @Author : wxw
# @Contact : xwwei@lighten.ai
# @Desc : TensorFlow ROI net: crops feature maps around the score-map peak
import tensorflow as tf
from collections import OrderedDict
from utils import config
import numpy as np
class RoiNets:
    """Crop backbone feature maps around the peak of a 2-channel score map.

    Construction builds, in this order:

    * ``positions`` -- one normalized ``[y1, x1, y2, x2]`` box per image,
      placed around the arg-max of softmax channel 1 (see ``get_position``);
    * ``box_ind`` -- box-to-image indices ``0 .. config.batch_size - 1``;
    * ``cut_maps`` -- an ``OrderedDict`` of per-endpoint tensors
      (see ``get_roi_maps``);
    * ``page_info`` -- whatever ``self.get_page_info()`` returns.

    Args:
        scores: logits split into two halves along axis 3 (presumably
            ``(config.batch_size, H, W, 2)`` -- TODO confirm).
        endpoints: mapping from endpoint name to a rank-4 feature tensor.
        height: per-endpoint crop heights, indexed in ``net_name`` order.
        width: per-endpoint crop widths, indexed in ``net_name`` order.
        net_name: ordered endpoint names; all but the last are cropped,
            the last is passed through unchanged.
        is_training: stored on the instance and in ``self.batch_norm``.
    """

    def __init__(self, scores, endpoints, height, width, net_name, is_training=True):
        self.is_training = is_training
        self.scores = scores
        self.endpoints = endpoints
        # (sic) attribute name kept as-is because other code may read it.
        self.witdh = width
        self.height = height
        self.net_name = net_name
        # Batch-norm keyword arguments; not consumed by the methods shown here.
        self.batch_norm = {
            "is_training": is_training,
            "center": True,
            "scale": True,
            "decay": 0.9,
            "epsilon": 0.001,
        }
        # Order matters: get_roi_maps needs both positions and box_ind.
        self.positions = self.get_position(self.scores)
        self.box_ind = tf.constant(np.arange(config.batch_size), tf.int32)
        self.cut_maps = self.get_roi_maps()
        # NOTE(review): get_page_info is not defined in this file; it must be
        # supplied elsewhere (subclass / missing code) or construction raises
        # AttributeError -- confirm.
        self.page_info = self.get_page_info()

    def get_position(self, scores, half_size=0.25):
        """Return one normalized crop box per image around the score peak.

        Softmax channel 1 is flattened per image. Its arg-max is mapped back
        to a (row, col) grid cell, normalized by the grid height/width.

        The box runs from ``half_size`` above the peak row down to the peak
        row itself. Horizontally it spans ``half_size`` on either side of the
        peak column. Both axes are clipped to ``[0, 1]``.

        Args:
            scores: logits with two channels along axis 3 (presumably
                ``(config.batch_size, H, W, 2)`` -- TODO confirm).
            half_size: box extent in normalized units; the default 0.25
                reproduces the original hard-coded value.

        Returns:
            float32 tensor of shape ``(config.batch_size, 4)`` with rows
            ``[y1, x1, y2, x2]``, the layout ``tf.image.crop_and_resize``
            expects for ``boxes``.
        """
        probs = tf.nn.softmax(scores)
        # Keep channel 1 of the 2-way softmax (presumably the foreground class).
        fg = tf.split(probs, num_or_size_splits=2, axis=3)[1]
        # Take the grid size from the tensor instead of assuming 224x224;
        # int64 so it matches the dtype returned by tf.argmax.
        grid = tf.shape(fg, out_type=tf.int64)
        grid_h, grid_w = grid[1], grid[2]
        flat = tf.reshape(fg, [config.batch_size, -1])
        peak = tf.argmax(flat, axis=1)  # row-major index into the H*W grid
        # Divide in float32: "/" on integer tensors floors under Python 2.
        row = tf.expand_dims(tf.cast(peak // grid_w, tf.float32), 1) / tf.cast(grid_h, tf.float32)
        col = tf.expand_dims(tf.cast(peak % grid_w, tf.float32), 1) / tf.cast(grid_w, tf.float32)
        top = tf.maximum(0.0, row - half_size)
        left = tf.maximum(0.0, col - half_size)
        right = tf.minimum(1.0, col + half_size)
        # NOTE(review): the box ends at the peak row (extends upward only),
        # exactly as in the original code -- confirm this is intended.
        bottom = row
        return tf.concat([top, left, bottom, right], axis=1)

    def get_roi_maps(self):
        """Build the ``cut_maps`` dict of per-endpoint tensors.

        Each endpoint named in ``net_name`` except the last is cropped with
        ``tf.image.crop_and_resize`` using ``self.positions`` and
        ``self.box_ind``. Entry ``i`` is resized to
        ``[self.height[i], self.witdh[i]]``.

        The last endpoint is stored uncropped. The static axis-3 size of
        every endpoint is recorded in ``self.chanels`` (sic) in the same
        order. Each entry is printed for debugging.

        Returns:
            OrderedDict mapping ``"cut_map_<i>"`` to a tensor, in
            ``net_name`` order.

        Raises:
            ValueError: if ``net_name`` is empty.
        """
        if not self.net_name:
            raise ValueError("net_name must contain at least one endpoint name")
        last = len(self.net_name) - 1
        cut_maps = OrderedDict()
        self.chanels = []
        for i, name in enumerate(self.net_name[:last]):
            net = self.endpoints[name]
            self.chanels.append(net.get_shape()[3].value)
            cut_maps["cut_map_%d" % i] = tf.image.crop_and_resize(
                image=net,
                boxes=self.positions,
                box_ind=self.box_ind,
                crop_size=[self.height[i], self.witdh[i]])
        # The final endpoint is used at full resolution, without cropping.
        net = self.endpoints[self.net_name[last]]
        cut_maps["cut_map_%d" % last] = net
        self.chanels.append(net.get_shape()[3].value)
        for idx, tensor in enumerate(cut_maps.values()):
            print('[%d]:' % idx, tensor)
        return cut_maps
# 网友评论 ("user comments") -- stray web-page text, not part of the code