美文网首页
tf.concat用法浅析

tf.concat用法浅析

作者: 陈晓峥 | 来源:发表于2018-10-31 16:57 被阅读0次

    tf.concat(concat_dim, values, name='concat')

    第一个参数concat_dim:必须是一个数,表明在哪一维上连接

    values 代表你要连接的矩阵

    废话不多说 直接上代码

    t1 = [[[[1, 2, 3], [4, 5, 6],[20, 21, 22]]]]

    t2 = [[[[7, 8, 9], [10, 11, 12], [25, 26, 27]]]]

    t3 = [[[[13, 14, 15], [16, 17, 18], [28, 29, 20]]]]

    print(tf.Variable(t1).shape)   #(1, 1, 3, 3)

    print(tf.Variable(t2).shape)   #(1, 1, 3, 3)

    print(tf.Variable(t3).shape)  #(1, 1, 3, 3)

    data = tf.Variable(tf.concat([t1, t2, t3], 0))     # 0 代表我将t1的第一个index相加 为 3

    data1 = tf.Variable(tf.concat([t1, t2, t3], 3))   # 3 代表我将t1的第四个index相加  为9

    init = tf.global_variables_initializer()

    with tf.Session() as sess:

    sess.run(init)

    print("test:\n", sess.run(data), "data =", data.shape)    # (3, 1, 3, 3)

    print("test:\n", sess.run(data1), "data1 =", data1.shape) #(1, 1, 3, 9)


    如果tf.concat 超过了矩阵的长度如将data1 = tf.Variable(tf.concat([t1, t2, t3], 3)) 改为

    data1 = tf.Variable(tf.concat([t1, t2, t3], 4)) 则会报错

    报错信息为

    ValueError: Shape must be at least rank 5 but is rank 4 for 'concat_1' (op: 'ConcatV2') with input shapes: [1,1,3,3], [1,1,3,3], [1,1,3,3], [] and with computed input tensors: input[3] = <4>.

    大概意思实 你的t1 t2 t3 的张量最大是到第四个维度, 而你输入4 代表要对张量的第五个维度进行相加,类似于数组越界,所以报错

    如有问题欢迎大家指正,谢谢

    相关文章

      网友评论

          本文标题:tf.concat用法浅析

          本文链接:https://www.haomeiwen.com/subject/yvwatqtx.html