关于a.sum(axis=1),网上有很多博客解释,但都说的不明白。我在stackoverflow上发现一个解释[1],很容易记忆,录在这里。
code.py
import numpy as np
if __name__ == '__main__':
a = np.arange(30).reshape(2, 3, 5)
print(a)
print(a[0,:,:]+a[1,:,:])
print("\n")
print(a.sum(axis=0))
print("\n\n")
print(a[:,0,:]+a[:,1,:]+a[:,2,:])
print("\n")
print(a.sum(axis=1))
print("\n\n")
print(a[:,:,0]+a[:,:,1]+a[:,:,2]+a[:,:,3]+a[:,:,4])
print("\n")
print(a.sum(axis=2))
a.sum(axis=0)的结果同a[0,:,:]+a[1,:,:]相同。
a.sum(axis=1)的结果同a[:,0,:]+a[:,1,:]+a[:,2,:]相同。
a.sum(axis=2)的结果同a[:,:,0]+a[:,:,1]+a[:,:,2]+a[:,:,3]+a[:,:,4]相同。
Reference:
[1] How does NumPy Sum (with axis) work?
网友评论