张量形状的理解与相关操作
一.张量的形状的判断
Drawing (1).png这里的A,B,C分布表示维度0,1,2
那如何理解shape呢,由左图我们可以看到
和A这个方括号同维度的有[1 2]
,[4 5]
两个,所以维度0的长度为2,
而和B同维度的有0
,2
两个,所以维度1的长度为2
所以左边张量的shape=[2,2]
同理右图的张量,
和A方括号同维度的只有[[1 2 3] [4 5 6]]
,所以维度0长度为1
和B方括号同维度的有[1 2 3]
, [4 5 6]
两个,所以维度1长度为2
和C同维度的有1
,2
,3
,所以维度2长度为3
所以shape=[1,2,3]
二. tf.squeeze(input, squeeze_dims=None, name=None),维度去除
1.去掉所有长度为1的维度(相当于去除那个维度的括号)。
举个栗子:
#coding=utf-8
import tensorflow as tf;
import numpy as np;
B = np.array([[[[1],[2],[3] ],[[4],[5],[6] ]]])
#去除维度0和维度3,因为这两个维度长度都为1
y = tf.squeeze(B,0)
with tf.Session() as sess:
print (sess.run(y),'\n')
输出:
[[1 2 3]
[4 5 6]]
2.也可以去掉指定索引的维度(该维度长度必须为1):
#coding=utf-8
import tensorflow as tf;
import numpy as np;
#shape=[1,2,3,1]
B = np.array([[[[1],[2],[3] ],[[4],[5],[6] ]]])
#去除维度0
y = tf.squeeze(B,[0])
with tf.Session() as sess:
print (sess.run(y),'\n')
输出:
[[[1]
[2]
[3]]
[[4]
[5]
[6]]]
三. tf.expand_dims(input, dim, name=None),扩展维度
作用:跟squeeze作用相反,它在维度dim上扩展一个长度为1的维度,原维度dim则被排在后面
#coding=utf-8
import tensorflow as tf;
import numpy as np;
B = np.array([[3,4],[5,6]])
# 在维度0的元素前面加括号
y = tf.expand_dims(B,0)
y1 = tf.expand_dims(B,2)
#-1表示最后一维
y2 = tf.expand_dims(B,-1)
print(B,'\n')
with tf.Session() as sess:
print ('y:shape=',y.shape,'\n',sess.run(y),'\n')
print ('y1:shape=',y1.shape,'\n',sess.run(y1),'\n')
print ('y2:shape=',y2.shape,'\n',sess.run(y2),'\n')
输出:
y:shape= (1, 2, 2)
[[[3 4]
[5 6]]]
y1:shape= (2, 2, 1)
[[[3]
[4]]
[[5]
[6]]]
y2:shape= (2, 2, 1)
[[[3]
[4]]
[[5]
[6]]]
四.tf.transpose(input, [dimension_1, dimenaion_2,..,dimension_n])
作用:交换维度
举个栗子:
A = np.array([[[1,2,3],[4,5,6]]])
#即:
[[
[1 2 3]
[4 5 6]
]]
如果x=tf.transpose(A, [0,2,1])
1.那么首先找到维度0,2,1的长度
维度0的长度: 1
维度1的长度: 2
维度2的长度: 3
2.再按序写出0,2维的形状的张量:
[
[
[ ]
[ ]
[ ]
]
]
3.若x的最后一维长度比A最后一维的长度小,则取将同列的元素按序放入x的最后一维,否则将x的同行元素按序放入最后一维,这里x和A的最后一维长度分别为2,3,所以将同列写入最后一维,最后的结果为:
[[[1 4]
[2 5]
[3 6]]]
代码验证:
import tensorflow as tf;
import numpy as np;
A = np.array([[[1,2,3],[4,5,6]]])
x = tf.transpose(A, [0,2,1])
y = tf.transpose(A, [0,1,2])
with tf.Session() as sess:
print ('A:\n',A,'\n')
print ('x:\n',sess.run(x),'\n')
print ('y:\n',sess.run(y),'\n')
输出:
A:
[[[1 2 3]
[4 5 6]]]
x:
[[[1 4]
[2 5]
[3 6]]]
y:
[[[1 2 3]
[4 5 6]]]
网友评论