tf.concat(concat_dim, values, name='concat')
第一个参数concat_dim:必须是一个数,表明在哪一维上连接
values 代表你要连接的矩阵
废话不多说 直接上代码
t1 = [[[[1, 2, 3], [4, 5, 6],[20, 21, 22]]]]
t2 = [[[[7, 8, 9], [10, 11, 12], [25, 26, 27]]]]
t3 = [[[[13, 14, 15], [16, 17, 18], [28, 29, 20]]]]
print(tf.Variable(t1).shape) #(1, 1, 3, 3)
print(tf.Variable(t2).shape) #(1, 1, 3, 3)
print(tf.Variable(t3).shape) #(1, 1, 3, 3)
data = tf.Variable(tf.concat([t1, t2, t3], 0)) # 0 代表我将t1的第一个index相加 为 3
data1 = tf.Variable(tf.concat([t1, t2, t3], 3)) # 3 代表我将t1的第四个index相加 为9
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print("test:\n", sess.run(data), "data =", data.shape) # (3, 1, 3, 3)
print("test:\n", sess.run(data1), "data1 =", data1.shape) #(1, 1, 3, 9)
如果tf.concat 超过了矩阵的长度如将data1 = tf.Variable(tf.concat([t1, t2, t3], 3)) 改为
data1 = tf.Variable(tf.concat([t1, t2, t3], 4)) 则会报错
报错信息为
ValueError: Shape must be at least rank 5 but is rank 4 for 'concat_1' (op: 'ConcatV2') with input shapes: [1,1,3,3], [1,1,3,3], [1,1,3,3], [] and with computed input tensors: input[3] = <4>.
大概意思实 你的t1 t2 t3 的张量最大是到第四个维度, 而你输入4 代表要对张量的第五个维度进行相加,类似于数组越界,所以报错
如有问题欢迎大家指正,谢谢
网友评论