美文网首页
tf.concat(), tf.stack(), tf.unst

tf.concat(), tf.stack(), tf.unst

作者: D_Major | 来源:发表于2019-03-28 17:59 被阅读0次

    参考 https://blog.csdn.net/loseinvain/article/details/79638183
    https://blog.csdn.net/chengshuhao1991/article/details/78545723
    输入两个二维数组如下:

    a = tf.constant([[1,2,3],[3,4,5]]) # shape (2,3)
    b = tf.constant([[7,8,9],[10,11,12]]) # shape (2,3)
    

    tf.concat相当于numpy中的np.concatenate函数,用于将两个张量在某一个维度(axis)合并起来,例如:

    ab1 = tf.concat([a,b], axis=0) # shape(4,3)
    [[ 1  2  3]
     [ 3  4  5]
     [ 7  8  9]
     [10 11 12]]
    
    ab2 = tf.concat([a,b], axis=1) # shape(2,6)
    [[ 1  2  3  7  8  9]
     [ 3  4  5 10 11 12]]
    

    tf.stack其作用类似于tf.concat,都是拼接两个张量,而不同之处在于,tf.concat拼接的是除了拼接维度axis外其他维度的shape完全相同的张量,并且产生的张量的阶数不会发生变化,而tf.stack则会在新的张量阶上拼接,产生的张量的阶数将会增加。

    tf.stack()就是以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R+1的张量。即将张量数组以指定的轴,提高一个维度。
    假设要转变的张量数组values(如[x, y])的长度为N,其中的每个张量(如x, y)的形状为(A, B, C)。
    如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。
    如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。
    如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。

    a = tf.constant([[1,2,3],[3,4,5]])
    b = tf.constant([[7,8,9],[10,11,12]])
    ab3 = tf.stack([a,b], axis=0)
    [[[ 1  2  3]
      [ 3  4  5]]
    
     [[ 7  8  9]
      [10 11 12]]]
    
    ab4 = tf.stack([a,b], axis=1)
    [[[ 1  2  3]
      [ 7  8  9]]
    
     [[ 3  4  5]
      [10 11 12]]]
    
    ab5 = tf.stack([a,b], axis=2)
    [[[ 1  7]
      [ 2  8]
      [ 3  9]]
    
     [[ 3 10]
      [ 4 11]
      [ 5 12]]]
    


    ‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]],形状是(3,4),维度为2
    ‘y’ is [[4,4,4,4],[5,5,5,5],[6,6,6,6]],形状是(3,4),维度为2
    stack([x,y]) => [[[1,1,1,1],[2,2,2,2],[3,3,3,3]], [[4,4,4,4],[5,5,5,5],[6,6,6,6]]] # axis的值默认为0。输出的形状为(2, 3, 4)
    stack([x,y],axis=1) => [[[1,1,1,1],[4,4,4,4]],[[2,2,2,2],[5,5,5,5]],[[3,3,3,3],[6,6,6,6]]] # axis的值为1。输出的形状为(3, 2, 4)
    stack([x,y],axis=2) => [[[1,4],[1,4],[1,4],[1,4]],[[2,5],[2,5],[2,5],[2,5]],[[3,6],[3,6],[3,6],[3,6]]]# axis的值为2。输出的形状为(3, 4, 2)

    axis可这样理解:stack要将一组N个相同形状的张量(如[x, y])提高一个维度。axis就是在和原来形状相同的张量里,将axis指定的维度里每一个元素用拼接后的数组代替。如axis=2,表示在指定的第2个维度(数值),将原来的每一个数值(如1), 用x和y对应位置的数值拼接而成的数组(如[1, 4])代替,即从(A, B)转变为(A, B, N)。
    对两个二维数组的拼接,axis=0则表示在三维层面上进行拼接,操作单位为二维矩阵;axis=1则表示在二维层面上进行拼接,操作单位为行向量,一一对应的进行拼接;axis=2则表示在一维层面上进行拼接,操作单位为数值,进行point-wise的拼接。

    tf.unstack与tf.stack的操作相反,是将一个高阶数的张量在某个axis上分解为低阶数的张量,例如:

    a1 = tf.unstack(ab3, axis=0)
    [array([[1, 2, 3],
           [3, 4, 5]], dtype=int32), 
     array([[7, 8, 9],
           [10, 11, 12]], dtype=int32)]
    
    a2 = tf.unstack(ab3, axis=1)
    [array([[1, 2, 3],
           [7, 8, 9]], dtype=int32), 
     array([[3, 4, 5],
           [10, 11, 12]], dtype=int32)]
    

    对tf.concat()的也可以做unstack操作

    a3 = tf.unstack(ab1, axis=0)
    [array([1, 2, 3], dtype=int32), 
     array([3, 4, 5], dtype=int32), 
     array([7, 8, 9], dtype=int32), 
     array([10, 11, 12], dtype=int32)]
    
    a4 = tf.unstack(ab1, axis=2)
    [array([[1, 7], [2, 8]], dtype=int32), 
     array([[3, 9], [3, 10]], dtype=int32), 
     array([[4, 11], [5, 12]], dtype=int32)] 
    

    相关文章

      网友评论

          本文标题:tf.concat(), tf.stack(), tf.unst

          本文链接:https://www.haomeiwen.com/subject/qgycbqtx.html