美文网首页
[PyTorch中文文档]-Package参考-torch-张量

[PyTorch中文文档]-Package参考-torch-张量

作者: 六千宛 | 来源:发表于2021-08-04 16:10 被阅读0次

    torch.is_tensor

    torch.is_tensor(obj)
    

    如果obj是tensor的话返回true,否则返回false。

    对应的另一个方法是:

    isinstance(obj, Tensor)
    

    注意,torch.is_tensor(obj)是torch的一个方法,而isinstance(obj, Tensor)是python自带的一个方法,这两个是等价的。当然isinstance(obj, type)这个方法可以检查任何类型,如果检查出obj是type类型返回true,否则返回false。

    isinstance(obj, Tensor)这种方法更适合于静态检查(例如更适合mypy等静态检查工具进行检查)并且也更加直观(这个倒是我觉得两个方法都挺直观的),所以更推荐使用isinstance(obj, Tensor)这种方法。

    torch.is_storage

    torch.is_storage(obj)
    

    测试obj是不是storage类型,如果是的话就返回True,否则返回False。

    什么是Storage?
      Storage类型是pytorch中的一个类型,它与tensor是对应的。
    tensor 分为头信息区(Tensor)和存储区(Storage)。

    信息区(Tensor)主要存储tensor的形状(size)、步长(stride)、数据类型(type)等信息,其真正的数据保存为连续数组,存储在存储区(Storage)中。

    一般来说pytorch中tensor的数据很大,可能是成千上万的,所以我们信息区(Tensor)一般来说占用的内存比较少,主要内存的占用取决于tensor中元素的数目,也就是存储区(Storage)的大小。

    torch.set_default_tensor_type(t)

    torch.set_default_tensor_type(t)
    

    这个方法的意思是设置pytorch中默认的浮点类型,一般使用pytorch进行运算时候使用的都是浮点数来进行计算,所以设置默认浮点数有时候也很重要。

    虽然这个方法和曾经的torch.set_default_dtype(d)确实功能很相似,但是实际上今天介绍的这个方法更强大一些(注意两个方法都只可以设置浮点数的默认类型,不可以设置整型的默认类型)。当然这个方法使用后也可以使用torch.get_default_dtype()来获取设置的默认浮点类型。

    Tensor有不同的数据类型,每种类型分别有对应CPU和GPU版本(HalfTensor除外)。默认的Tensor是FloatTensor,可通过torch.set_default_tensor_type修改默认tensor类型(如果默认类型为GPU tensor,则所有操作都将在GPU上进行),HalfTensor是专门为GPU设计的,相同元素个数使用的空间更少,解决显存不足的问题,但是由于精度不足可能会出现溢出的问题。

    pytorch中可用的浮点类型


    image.png
    import torch
    torch.set_default_tensor_type(torch.cuda.DoubleTensor)
    a = torch.tensor([2., 3])
    print(a.dtype,a.device)
    

    torch.numel

    torch.numel(input)->int
    

    numel就是"number of elements"的简写。numel()可以直接返回int类型的元素个数

    import torch
    a = torch.randn(1, 2, 3, 4)
    b = a.numel()
    print(type(b)) #int
    print(b) #24
    

    torch.set_printoptions

    torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None)
    

    precision是每一个元素的输出精度,默认是八位;
    threshold是输出时的阈值,当tensor中元素的个数大于该值时,进行缩略输出,默认时1000;
    edgeitems是输出的维度,默认是3;
    linewidth字面意思,每一行输出的长度;
    profile– pretty打印的完全默认值。 可以覆盖上述所有选项 (默认为short, full)

    相关文章

      网友评论

          本文标题:[PyTorch中文文档]-Package参考-torch-张量

          本文链接:https://www.haomeiwen.com/subject/tmuuvltx.html