美文网首页
直觉化理解PyTorch中的dim和Numpy中的axis

直觉化理解PyTorch中的dim和Numpy中的axis

作者: LabVIEW_Python | 来源:发表于2022-04-30 10:59 被阅读0次

若把PyTorch看做支持GPU和自动微分功能的Numpy,那么PyTorch中的dim和Numpy中的axis是一样的概念。

torch.sum(input, dim, keepdim=False, *, dtype=None) -> Tensor
numpy.sum(a, axis=None, dtype=None, out=None, keepdims=<no value>, initial=<no value>, where=<no value>)

首先, dim和axis的默认值是None,意思是对所有元素进行操作,范例代码如下:

import torch
a = torch.arange(0,2*3*4)
a = a.reshape([2,3,4])
print(a,a.shape)
print('---------------------------------------------')
print(torch.sum(a),torch.sum(a).shape)

import numpy as np
print('---------------------------------------------')
print('---------------------------------------------')
b = np.arange(0,2*3*4)
b = b.reshape([2,3,4])
print(b,b.shape)
print('---------------------------------------------')
print(np.sum(b),np.sum(b).shape)
运行结果: axis=None,“坍塌”为标量

其次,axis和dim可以直觉化理解为:其它axis和dim保持不动,指定的dim坍塌(collapse), dim=0, dim=1, dim=2的Collapse示意图如下所示:

dim=0,维度0“坍塌” dim=1,维度1“坍塌” dim=2,维度2“坍塌”
范例程序如下
import torch
a = torch.arange(0,2*3*4)
a = a.reshape([2,3,4])
print(a,a.shape)
print('---------------------------------------------')
print(torch.sum(a,dim=0),torch.sum(a,dim=0).shape)

import numpy as np
print('---------------------------------------------')
print('---------------------------------------------')
b = np.arange(0,2*3*4)
b = b.reshape([2,3,4])
print(b,b.shape)
print('---------------------------------------------')
print(np.sum(b,axis=0),np.sum(b,axis=0).shape)
运行结果: dim=0, 0维“坍塌”

相关文章

网友评论

      本文标题:直觉化理解PyTorch中的dim和Numpy中的axis

      本文链接:https://www.haomeiwen.com/subject/snxyyrtx.html