美文网首页深度学习研究所
tf.nn.conv2d中传入的第一二个参数的数据格式问题

tf.nn.conv2d中传入的第一二个参数的数据格式问题

作者: 西方失败9527 | 来源:发表于2017-09-17 15:48 被阅读0次

    看到知乎上这样一个问题:

    下面这个图所示,输入数据是一个2个通道3*3的数据,过滤器是一个具有两个通道的2*2的数据,按照一般卷积过程,即如果所示结果是一个通道的2*2的数据。

    但是在tensorflow中,我们如下实现:

    k = tf.constant([ 1,2 ,3,4,

                                5,6,7,8], dtype=tf.float32, name='k')

    i = tf.constant([

                           1, 3, 5,

                           1, 3, 5,

                           1, 3, 5,

                            2, 4, 6,

                            2, 4, 6,

                            2, 4, 6

                            ], dtype=tf.float32, name='i')

    kernel = tf.reshape(k, [2, 2, 2, 1], name='kernel')

    image  = tf.reshape(i, [1, 3, 3, 2], name='image')

    #res = tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID")

    res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))# VALID means no padding

    with tf.Session() as sess:

                print(sess.run(res))

    结果不对

    原因原来是data_format 参数的问题,图像数据格式定义了一批图片数据的存储顺序。在调用 TensorFlow API 时会经常看到 data_format 参数:

    data_format 默认值为 "NHWC",也可以手动设置为 "NCHW"。这个参数规定了 input Tensor 和 output Tensor 的排列方式。

    data_format 设置为 "NHWC" 时,排列顺序为 [batch, height, width, channels];

                          设置为 "NCHW" 时,排列顺序为 [batch, channels, height, width]。

    其中 N 表示这批图像有几张,H 表示图像在竖直方向有多少像素,W 表示水平方向像素数,C 表示通道数(例如黑白图像的通道数 C = 1,而 RGB 彩色图像的通道数 C = 3)。为了便于演示,我们后面作图均使用 RGB 三通道图像。两种格式的区别如下图所示:

    NCHW 中,C 排列在外层,每个通道内像素紧挨在一起,即 'RRRRRRGGGGGGBBBBBB' 这种形式。

    NHWC 格式,C 排列在最内层,多个通道对应空间位置的像素紧挨在一起,即 'RGBRGBRGBRGBRGBRGB' 这种形式。

    于是我们的程序中将数据顺序修改即可:

    k = tf.constant([

    1, 5,

    2, 6,

    3, 7,

    4, 8

    ], dtype=tf.float32, name='k')

    i = tf.constant([

    1, 2, 3,

    4, 5, 6,

    1, 2, 3,

    4, 5, 6,

    1, 2, 3,

    4, 5, 6

    ], dtype=tf.float32, name='i')

    kernel = tf.reshape(k, [2, 2, 2, 1], name='kernel')

    image  = tf.reshape(i, [1, 3, 3, 2], name='image')

    #res = tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID")

    res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))# VALID means no padding

    with tf.Session() as sess:

                print(sess.run(image))

                print("------------------")

                print(sess.run(kernel))

                print("------------------")

               print(sess.run(res))

    最终能如愿以偿得到如图右边的结果。不过feature map的172应该改为174,手算也该如此

    主要参考:http://mp.weixin.qq.com/s/I4Q1Bv7yecqYXUra49o7tw

    相关文章

      网友评论

        本文标题: tf.nn.conv2d中传入的第一二个参数的数据格式问题

        本文链接:https://www.haomeiwen.com/subject/axiesxtx.html