美文网首页我爱编程
Numpy中stack()函数的理解

Numpy中stack()函数的理解

作者: 逆鳞L | 来源:发表于2018-03-27 11:10 被阅读0次

np.stack(array,axis,out=None),函数原型。
其中最重要是的这个axis怎么理解的。
举例说明:
arrays = [np.random.randn(3, 4) for _ in range(10)]
会生成一个 10 *( 3 * 4 )的矩阵列表。十个矩阵,每个矩阵是(3 * 4)大小。
首先说明一下axis的映射。在这个例子中,10->axis=0 ,3->axis=1

>>>np.stack(arrays,axis=1)
array([[[-0.42233185, -0.13270788, -0.47724388, -1.48881134],
        [ 0.2284937 , -0.30139984,  0.15633374,  0.04428078],
        [ 2.0193316 ,  0.1098357 , -0.32044757, -1.24868601],
        [ 0.9859909 , -0.42781564,  0.57524126,  0.58154297],
        [-0.13059124,  2.15207301,  0.36007904, -0.71344781],
        [-1.68010975,  1.25350273,  0.11073033, -0.28531604],
        [ 0.60021096, -0.18691447,  1.49261775,  0.47628294],
        [-0.18268831, -0.32463742, -0.89726008,  0.19245843],
        [-0.27384598,  0.56068318,  1.57096001,  1.11169077],
        [ 0.27035354, -0.54258351, -0.69891459,  1.84282464]],

       [[ 1.44874184, -1.6645958 ,  1.14128754, -2.26945958],
        [ 0.28754711, -1.59591539, -0.92798468, -0.05021877],
        [ 1.09050239, -0.86881164, -0.59820951, -0.39628311],
        [-1.09540304, -0.33438594, -0.71075442, -1.48691938],
        [ 0.7155825 ,  0.24710929, -0.65019501, -1.24407802],
        [-0.11059045, -1.57851632,  1.34142995, -0.44438407],
        [ 0.9258746 ,  1.62418684, -0.25380587, -1.1423341 ],
        [-1.76337136,  0.55031978,  1.25834475,  0.53257722],
        [ 0.05755626,  1.16156935, -1.84999546,  1.57175386],
        [ 0.48836813, -0.21907532, -0.78655392,  0.51705705]],

       [[-0.24451876, -0.09881284,  1.17611246,  0.81276037],
        [ 0.89510841,  0.9106155 ,  0.4923826 , -0.07364133],
        [-0.0670429 ,  0.72968107, -1.31473173, -0.31313322],
        [ 0.62314248,  0.97792175,  0.0840199 , -0.38035465],
        [ 0.70222737,  0.53761069,  0.50546661, -2.02777762],
        [-0.85454667, -0.76359383, -0.25280887, -0.94252057],
        [ 0.38294622, -0.38729216,  0.03757319, -0.48955485],
        [ 1.52718003,  1.14814816,  1.33147053, -0.50341043],
        [-0.38600834,  0.19781327, -0.35596671,  1.59331045],
        [-0.07073478, -1.4710414 ,  1.95192939, -0.83379204]]])
>>> np.stack(arrays, axis=1).shape
(3, 10, 4)

为什么会变成 3 * 10 * 4了呢。首先我们的函数是对 10 * 3 * 4 中的3,也就是axis=1,进行了堆叠。
那么这个 axis = 1,在十个矩阵中代表什么呢?代表 每个矩阵中的一行。所以这个函数的操作就是,把10矩阵中的第i行拿出来拼成一个矩阵。因为一个矩阵有三行,所以堆叠后的矩阵就是,3 * 10 * 4,这个10 * 4,就是原来矩阵中,十个矩阵的第一行,第二行,第三行,拼接而成的。所以是 3 * 10 * 4。

相关文章

网友评论

    本文标题:Numpy中stack()函数的理解

    本文链接:https://www.haomeiwen.com/subject/xmpjcftx.html