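If you think of PyTorch as NumPy with GPU support and automatic differentiation added, then dim in PyTorch and axis in NumPy are the same concept. Compare the two signatures: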
torch.sum(input, dim, keepdim=False, *, dtype=None) -> Tensor
numpy.sum(a, axis=None, dtype=None, out=None, keepdims=<no value>, initial=<no value>, where=<no value>)
First, dim and axis default to None, which means the operation is applied over all elements. Example code:
import torch
a = torch.arange(0,2*3*4)
a = a.reshape([2,3,4])
print(a,a.shape)
print('---------------------------------------------')
print(torch.sum(a),torch.sum(a).shape)
import numpy as np
print('---------------------------------------------')
print('---------------------------------------------')
b = np.arange(0,2*3*4)
b = b.reshape([2,3,4])
print(b,b.shape)
print('---------------------------------------------')
print(np.sum(b),np.sum(b).shape)
Output:
![](https://img.haomeiwen.com/i10758717/6aefea8dbefe5f2b.png)
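Second, axis and dim can be understood intuitively like this: every other axis/dim stays in place, and the specified dim collapses. The animations below show the collapse for dim=0, dim=1, and dim=2: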
![](https://img.haomeiwen.com/i10758717/68a0e4d6e725f3f9.gif)
![](https://img.haomeiwen.com/i10758717/1fd13fed70efa527.gif)
![](https://img.haomeiwen.com/i10758717/8ec2ac4064b267ff.gif)
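To make the collapse picture concrete, here is a minimal pure-Python sketch of what summing over dim=0 does to a 2x3x4 array: for every fixed position (j, k) on the remaining axes, the values along axis 0 are added together. The helper name `sum_along_axis0` is made up for this illustration and is not part of PyTorch or NumPy; the real libraries do the same reduction in optimized native code.

```python
def sum_along_axis0(a):
    """Sum a nested list of shape (d0, d1, d2) over axis 0, giving shape (d1, d2)."""
    d0, d1, d2 = len(a), len(a[0]), len(a[0][0])
    # Axes 1 and 2 (indices j, k) stay in place; axis 0 (index i) is summed away.
    return [[sum(a[i][j][k] for i in range(d0)) for k in range(d2)]
            for j in range(d1)]


# The same 2x3x4 layout of the values 0..23 used in this article.
a = [[[i * 12 + j * 4 + k for k in range(4)] for j in range(3)] for i in range(2)]

result = sum_along_axis0(a)
print(result)
# [[12, 14, 16, 18], [20, 22, 24, 26], [28, 30, 32, 34]]
```

The result has shape 3x4: the axis of length 2 is gone, and each entry is the sum of the two values that were stacked along it.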
Example program:
import torch
a = torch.arange(0,2*3*4)
a = a.reshape([2,3,4])
print(a,a.shape)
print('---------------------------------------------')
print(torch.sum(a,dim=0),torch.sum(a,dim=0).shape)
import numpy as np
print('---------------------------------------------')
print('---------------------------------------------')
b = np.arange(0,2*3*4)
b = b.reshape([2,3,4])
print(b,b.shape)
print('---------------------------------------------')
print(np.sum(b,axis=0),np.sum(b,axis=0).shape)
Output:
![](https://img.haomeiwen.com/i10758717/1d86771e39302ba0.png)
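The example above covers only dim=0. As a rough sketch of the other two cases, the NumPy snippet below sums the same array over axis=1 and axis=2 and also shows `keepdims=True`, which keeps the collapsed axis as a length-1 axis instead of dropping it. Only NumPy is used here to keep the snippet small; the PyTorch counterparts would be `torch.sum(a, dim=1)`, `torch.sum(a, dim=2)`, and `keepdim=True` (note the singular spelling in PyTorch).

```python
import numpy as np

b = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Collapse axis 1 (length 3): shape (2, 3, 4) -> (2, 4)
sum_axis1 = np.sum(b, axis=1)

# Collapse axis 2 (length 4): shape (2, 3, 4) -> (2, 3)
sum_axis2 = np.sum(b, axis=2)

# keepdims=True keeps the collapsed axis with length 1: (2, 3, 4) -> (1, 3, 4)
kept = np.sum(b, axis=0, keepdims=True)

print(sum_axis1.shape, sum_axis2.shape, kept.shape)
# (2, 4) (2, 3) (1, 3, 4)
```

In each case the output shape is the input shape with the chosen axis removed, or set to 1 when `keepdims=True`.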