Tensorflow Channel-Wise Attention

Author: 李2狗子 | Published 2018-11-27 12:11

A CNN gives us a feature map with dimensions height × width × channel. Call it V.

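  1. Reshape V to U:

U = [u_1, u_2, \ldots, u_C], \quad u_i \in \mathbb{R}^{W \times H}

where u_i is the i-th channel of V and C is the number of channels.

  2. Then apply mean pooling to each channel to obtain the channel feature v:

v = [v_1, v_2, \ldots, v_C], \quad v \in \mathbb{R}^C

where the scalar v_i is the mean of u_i.

  3. Channel-wise attention model:

b = \tanh((W_c \otimes v + b_c) \oplus W_{hc} h_{t-1})

\beta = \mathrm{softmax}(W_i b + b_i)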

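Dimensions of the variables above:

  • W_c: k
  • W_{hc}: k × d, where d is the dimension of the hidden state h_{t-1}
  • W_i: k
  • b_c: k
  • b_i: 1

\otimes denotes the outer product. \oplus denotes adding a vector to a matrix: the vector is added to every column (broadcasting).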
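The three steps and the shapes listed above can be traced in a minimal NumPy sketch for a single sample, with no batch dimension. This is an illustration, not the author's code. The helper name channel_wise_attention_single and the random parameter values are assumptions. The sketch follows the formula's layout, in which W_c \otimes v is a k × C matrix and the k-vectors are added to every column.

```python
import numpy as np

def channel_wise_attention_single(V, h_prev, W_c, b_c, W_hc, W_i, b_i):
    """Hypothetical helper: channel-wise attention for ONE feature map.

    V: (H, W, C) feature map, h_prev: (d,) hidden state h_{t-1}
    W_c, b_c, W_i: (k,), W_hc: (k, d), b_i: scalar
    """
    H, W, C = V.shape
    # Step 1: reshape V to U; row i holds u_i (W x H values, stored flattened)
    U = V.reshape(H * W, C).T                    # (C, H*W)
    # Step 2: mean pooling per channel gives v in R^C
    v = U.mean(axis=1)                           # (C,)
    # Step 3: outer product W_c ⊗ v has shape (k, C)
    outer = np.outer(W_c, v)
    # b_c and W_hc h_{t-1} are k-vectors added to every column (the ⊕ step)
    b = np.tanh(outer + b_c[:, None] + (W_hc @ h_prev)[:, None])   # (k, C)
    # beta = softmax(W_i b + b_i): one weight per channel
    logits = W_i @ b + b_i                       # (C,)
    logits = logits - logits.max()               # numerical stability
    beta = np.exp(logits) / np.exp(logits).sum()
    return v, b, beta

rng = np.random.default_rng(0)
H, W, C, k, d = 4, 4, 3, 6, 5
V = rng.standard_normal((H, W, C))
h_prev = rng.standard_normal(d)
v, b, beta = channel_wise_attention_single(
    V, h_prev,
    W_c=rng.standard_normal(k), b_c=rng.standard_normal(k),
    W_hc=rng.standard_normal((k, d)), W_i=rng.standard_normal(k), b_i=0.1)
print(v.shape, b.shape, beta.shape)  # (3,) (6, 3) (3,)
```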

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np


def global_average_pool(x):
    # Average over height and width, keeping them as size-1 axes: (batch_size, 1, 1, channel)
    return tf.reduce_mean(x, axis=[1, 2], keepdims=True)


def channel_wise_attention(inputs, hidden_states, k):
    # inputs: (batch_size, height, width, channel); hidden_states: (batch_size, d)
    batch_size, height, width, channel = inputs.get_shape().as_list()
    batch_size1, d = hidden_states.get_shape().as_list()   # d = embed_size
    assert batch_size == batch_size1, "inputs and hidden_state should have the same batch size"
    feature_maps = inputs  # keep the original feature maps; `inputs` is reassigned below

    with tf.variable_scope("channel_wise_attention"):
        Wc = tf.get_variable("Wc", shape=(k,), dtype=tf.float32, initializer=tf.random_uniform_initializer(-1, 0))
        Whc = tf.get_variable("Whc", shape=(k, d), dtype=tf.float32, initializer=tf.random_uniform_initializer(-1, 0))
        Wi = tf.get_variable("Wi", shape=(k, ), dtype=tf.float32, initializer=tf.random_uniform_initializer(-1, 0))
        bc = tf.get_variable("bc", shape=(k, ), dtype=tf.float32, initializer=tf.random_uniform_initializer(-0.1, 0.1))
        bi = tf.get_variable("bi", shape=(), dtype=tf.float32, initializer=tf.random_uniform_initializer(-0.1, 0.1))

        # pool: channel feature v
        inputs = global_average_pool(inputs)        # batch_size, 1, 1, channel
        inputs = tf.reshape(inputs, (batch_size, channel))

        # outer product W_c ⊗ v (+ b_c), laid out as (channel, k) for each sample
        inputs = tf.reshape(inputs, (-1, 1))                # batch_size * channel, 1
        Wc = tf.reshape(Wc, shape=(-1, 1))                  # k, 1
        dot1 = tf.matmul(inputs, Wc, transpose_b=True)      # batch_size * channel, k
        dot1 = dot1 + bc                                    # bc broadcasts over the rows
        dot1 = tf.reshape(dot1, (batch_size, channel, k))

        # W_hc h_{t-1}
        dot2 = tf.matmul(hidden_states, Whc, transpose_b=True)  # batch_size, k
        dot2 = tf.expand_dims(dot2, 1)  # batch_size, 1, k

        b = tf.tanh(dot1 + dot2)        # batch_size, channel, k
        b = tf.reshape(b, shape=(-1, k))                    # batch_size * channel, k
        dot3 = tf.matmul(b, tf.expand_dims(Wi, -1))         # batch_size * channel, 1
        dot3 = tf.reshape(dot3, (batch_size, channel))
        beta = dot3 + bi

        attention = tf.nn.softmax(beta)             # batch_size, channel
        attention = tf.expand_dims(attention, 1)
        attention = tf.expand_dims(attention, 1)    # batch_size, 1, 1, channel

        # re-weight each channel of the ORIGINAL feature maps, not the pooled vector
        output = tf.multiply(feature_maps, attention)   # batch_size, height, width, channel
        return attention, output


if __name__ == '__main__':
    inputs = tf.constant(np.asarray(range(1, 2 * 4 * 4 * 3 + 1), dtype=np.float32).reshape((2, 4, 4, 3)))
    hidden_states = tf.constant(np.asarray(range(97, 97 + 2 * 5), dtype=np.float32).reshape((2, 5)))
    k = 6
    attention, output = channel_wise_attention(inputs, hidden_states, k)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)

    for variable in tf.trainable_variables():
        print(variable.name, variable.dtype)

    initializer = tf.global_variables_initializer()
    session.run(initializer)

    for o in session.run([attention, output]):
        print(o)
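The code above needs TensorFlow 1.x (tf.Session, tf.get_variable). For a framework-free check of the shapes, here is a minimal NumPy mirror of the same computation in the code's (batch, channel, k) layout. It re-weights the original feature maps. The function name channel_wise_attention_np is hypothetical. The parameters are random stand-ins that only imitate the TF initializer ranges, so the numbers will not match a TF run.

```python
import numpy as np

def channel_wise_attention_np(inputs, hidden_states, Wc, Whc, Wi, bc, bi):
    """Hypothetical NumPy mirror of channel_wise_attention (layout: batch, channel, k).

    inputs: (B, H, W, C), hidden_states: (B, d)
    Wc, Wi, bc: (k,), Whc: (k, d), bi: scalar
    """
    pooled = inputs.mean(axis=(1, 2))                      # (B, C) channel features
    dot1 = pooled[:, :, None] * Wc[None, None, :] + bc     # (B, C, k) outer product + bc
    dot2 = (hidden_states @ Whc.T)[:, None, :]             # (B, 1, k)
    b = np.tanh(dot1 + dot2)                               # (B, C, k)
    beta = b @ Wi + bi                                     # (B, C)
    beta = beta - beta.max(axis=-1, keepdims=True)         # numerical stability
    attention = np.exp(beta) / np.exp(beta).sum(axis=-1, keepdims=True)
    attention = attention[:, None, None, :]                # (B, 1, 1, C)
    return attention, inputs * attention                   # output: (B, H, W, C)

# Same toy inputs as the __main__ block above; parameters are random stand-ins.
inputs = np.arange(1, 2 * 4 * 4 * 3 + 1, dtype=np.float32).reshape((2, 4, 4, 3))
hidden_states = np.arange(97, 97 + 2 * 5, dtype=np.float32).reshape((2, 5))
k, d = 6, 5
rng = np.random.default_rng(0)
attention, output = channel_wise_attention_np(
    inputs, hidden_states,
    Wc=rng.uniform(-1, 0, k), Whc=rng.uniform(-1, 0, (k, d)),
    Wi=rng.uniform(-1, 0, k), bc=rng.uniform(-0.1, 0.1, k), bi=0.0)
print(attention.shape, output.shape)  # (2, 1, 1, 3) (2, 4, 4, 3)
```

To compare values rather than shapes, you could fetch the TensorFlow variables (for example with session.run) and pass them into this function in place of the random stand-ins.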

The computation is:
b = \tanh((W_c \otimes v + b_c) \oplus W_{hc} h_{t-1})

\beta = \mathrm{softmax}(W_i b + b_i)

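In the formula, v has dimension channel and W_c has dimension k, so their outer product has dimension k × channel. Adding b_c (dimension k) relies on broadcasting, so the result stays k × channel. Then W_{hc}h_{t-1} (dimension k) is added in the same broadcast way; this is what \oplus means.

Note one difference in the code: the product is laid out as (channel, k) rather than (k, channel). TensorFlow's broadcasting rules, like NumPy's, align the trailing dimension. So a k-dimensional vector broadcasts across the channel axis only when k is the last axis.

W_i has dimension k and b has dimension k × channel, so W_i b has dimension channel, and b_i is then added. b_i is said to have dimension 1; in TensorFlow such a scalar variable is declared with shape=().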
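Whether the k-dimensional b_c broadcasts depends on which axis holds k. The NumPy sketch below compares the formula's (k, channel) layout with the (channel, k) layout used in channel_wise_attention. It uses small made-up values for W_c, v, b_c and W_i, none of which come from the original code.

```python
import numpy as np

k, C = 6, 3
W_c = np.arange(1.0, k + 1)      # (k,)  made-up values
v = np.array([0.5, 1.0, 2.0])    # (C,)  made-up pooled channel features
b_c = np.full(k, 0.1)            # (k,)

# Formula layout: W_c ⊗ v is (k, C); b_c must be added to every column.
math_layout = np.outer(W_c, v) + b_c[:, None]    # (k, C)

# Code layout: v ⊗ W_c is (C, k), so a plain (k,) vector broadcasts
# over the leading channel axis without any extra reshape.
code_layout = np.outer(v, W_c) + b_c             # (C, k)

assert np.allclose(math_layout, code_layout.T)   # same numbers, transposed

# Adding a (k,) vector directly to a (k, C) matrix fails unless C == k,
# because broadcasting aligns trailing dimensions.
try:
    np.outer(W_c, v) + b_c
except ValueError as err:
    print("shape mismatch:", err)

# W_i b collapses the k axis and leaves one score per channel.
W_i = np.ones(k)
print((W_i @ math_layout).shape)   # (3,)
```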

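In the actual computation, however, every tensor carries an extra batch_size dimension on top of the dimensions in the formulas. So the code needs frequent reshape, expand_dims and transpose operations (for example transpose_b=True in tf.matmul).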
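The least obvious way the code handles the extra batch dimension is the reshape trick in the outer-product step. The NumPy sketch below checks that the code's flatten, matmul, reshape sequence equals a per-sample outer product. It uses random values, and the names pooled and reference are hypothetical. It also shows np.einsum as an alternative that needs no reshapes. TensorFlow has tf.einsum too, which should allow the same formulation, though that is not what the original code does.

```python
import numpy as np

rng = np.random.default_rng(1)
batch_size, channel, k = 2, 3, 6
pooled = rng.standard_normal((batch_size, channel))  # pooled channel features v
W_c = rng.standard_normal(k)

# Trick used in the TF code: flatten batch and channel, then a single matmul
# (batch*channel, 1) @ (1, k) computes v_j * W_c for every (sample, channel).
flat = pooled.reshape(-1, 1) @ W_c.reshape(1, -1)    # (batch*channel, k)
dot1 = flat.reshape(batch_size, channel, k)

# Reference: a per-sample outer product in the (channel, k) layout.
reference = np.stack([np.outer(pooled[i], W_c) for i in range(batch_size)])
assert np.allclose(dot1, reference)

# The same batched outer product with einsum, no reshapes needed.
assert np.allclose(dot1, np.einsum("bc,k->bck", pooled, W_c))
print(dot1.shape)  # (2, 3, 6)
```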

| Operation | Variable shape |
| --- | --- |
| input | inputs.shape=(batch_size, height, width, channel) |
| inputs = global_average_pool(inputs) | inputs.shape=(batch_size, 1, 1, channel) |
| inputs = tf.reshape(inputs, (batch_size, channel)) | inputs.shape=(batch_size, channel) |
| W_c \otimes v: | |
| inputs = tf.reshape(inputs, (-1, 1)) | inputs.shape=(batch_size * channel, 1) |
| Wc = tf.reshape(Wc, shape=(-1, 1)) | Wc.shape=(k, 1) |
| dot1 = tf.matmul(inputs, Wc, transpose_b=True) | dot1.shape=(batch_size * channel, k) |
| dot1 = dot1 + bc | dot1.shape=(batch_size * channel, k) |
| dot1 = tf.reshape(dot1, (batch_size, channel, k)) | dot1.shape=(batch_size, channel, k) |
| W_{hc}h_{t-1}: | hidden_states.shape=(batch_size, d), Whc.shape=(k, d) |
| dot2 = tf.matmul(hidden_states, Whc, transpose_b=True) | dot2.shape=(batch_size, k) |
| dot2 = tf.expand_dims(dot2, 1), so dot2 is broadcast-added along the channel axis | dot2.shape=(batch_size, 1, k) |
| b = tf.tanh(dot1 + dot2) | b.shape=(batch_size, channel, k) |
| b = tf.reshape(b, (-1, k)) | b.shape=(batch_size * channel, k) |
| dot3 = tf.matmul(b, tf.expand_dims(Wi, -1)) | dot3.shape=(batch_size * channel, 1) |
| dot3 = tf.reshape(dot3, (batch_size, channel)) | dot3.shape=(batch_size, channel) |
| beta = dot3 + bi | beta.shape=(batch_size, channel) |
| attention = tf.nn.softmax(beta) | attention.shape=(batch_size, channel) |

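By default, softmax operates on the last axis (-1), which yields attention with shape (batch_size, channel). Two expand_dims calls then turn attention into a (batch_size, 1, 1, channel) tensor. That shape broadcasts against the (batch_size, height, width, channel) feature maps in the final multiplication.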
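These last steps (softmax over the last axis, expand_dims to (batch_size, 1, 1, channel), then the broadcast multiply) can be checked in NumPy. The helper softmax_last_axis is a hypothetical stand-in for the default behavior of tf.nn.softmax. The feature maps and logits are random.

```python
import numpy as np

def softmax_last_axis(x):
    # Hypothetical helper mimicking tf.nn.softmax's default: normalize over axis -1.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
batch_size, height, width, channel = 2, 4, 4, 3
feature_maps = rng.standard_normal((batch_size, height, width, channel))
beta = rng.standard_normal((batch_size, channel))       # attention logits

attention = softmax_last_axis(beta)                     # (batch_size, channel)
assert np.allclose(attention.sum(axis=-1), 1.0)         # one distribution per sample

# Equivalent of the two expand_dims calls: (batch_size, 1, 1, channel)
attention = attention[:, np.newaxis, np.newaxis, :]
output = feature_maps * attention                       # broadcasts over height and width
print(attention.shape, output.shape)  # (2, 1, 1, 3) (2, 4, 4, 3)

# Every spatial position of channel c in sample i is scaled by the same weight.
i, c = 1, 2
assert np.allclose(output[i, :, :, c], feature_maps[i, :, :, c] * attention[i, 0, 0, c])
```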

Article link: https://www.haomeiwen.com/subject/nhiiqqtx.html