WHAT IS PYTORCH


Author: 碎嘴俞 | Published 2019-06-02 14:12

    Tensor

    Resizing

    If you want to resize or reshape a tensor, you can use the tensor method view():

    import torch

    x = torch.randn(4, 4)
    y = x.view(16)
    z = x.view(-1, 8)  # the size -1 is inferred from the other dimensions
    print(x.size(), y.size(), z.size())
    

    Out:

    torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])
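One point the example leaves implicit: view() never copies data. The result shares storage with the original tensor, so view() only succeeds when the new shape is compatible with the tensor's memory layout. The following sketch (variable names are illustrative) shows the shared storage, a view() call that fails on a transposed, non-contiguous tensor, and two common workarounds, reshape() and contiguous().

```python
import torch

x = torch.randn(4, 4)

# view() returns a tensor that shares storage with x: no data is copied
y = x.view(16)
y[0] = 100.0
print(x[0, 0].item())  # 100.0, because x and y share memory

# A transposed tensor is not contiguous, so view() cannot flatten it
xt = x.t()
view_failed = False
try:
    xt.view(16)
except RuntimeError:
    view_failed = True
print("view failed:", view_failed)

# reshape() returns a view when it can and falls back to a copy otherwise
flat = xt.reshape(16)
print(flat.shape)

# contiguous() makes a compactly laid-out copy, after which view() works
flat2 = xt.contiguous().view(16)
print(torch.equal(flat, flat2))
```

reshape() is the more forgiving choice when you do not care whether you get a view or a copy; view() is useful when you want a guarantee that no data is copied.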
    

    If you have a one-element tensor, use .item() to get its value as a Python number:

    import torch

    x = torch.randn(1)
    print(x)
    print(x.item())
    

    Out:

    tensor([-0.2028])
    -0.20277611911296844
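.item() only works on a tensor with exactly one element; calling it on a larger tensor raises an error. A minimal sketch, with illustrative values, of what .item() returns for floating-point and integer tensors, and of .tolist(), which converts a whole tensor into (nested) Python lists:

```python
import torch

# item() works only on tensors with exactly one element
f = torch.tensor([3.5]).item()  # a Python float
i = torch.tensor(7).item()      # a Python int for integer tensors
print(type(f).__name__, f)
print(type(i).__name__, i)

# For tensors with more than one element, tolist() converts everything
vals = torch.tensor([[1, 2], [3, 4]]).tolist()
print(vals)
```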
    
    NumPy Bridge

    Converting a Torch Tensor to a NumPy array with .numpy() does not copy the data: the tensor and the array share the same underlying memory (as long as the tensor is on the CPU), so changing one changes the other.

    import torch

    a = torch.ones(5)
    b = a.numpy()   # b shares memory with a
    print(b)
    a.add_(1)       # in-place add on the tensor
    print(a)
    print(b)        # b reflects the change
    

    Out:

    [1. 1. 1. 1. 1.]
    tensor([2., 2., 2., 2., 2.])
    [2. 2. 2. 2. 2.]
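Because the array returned by .numpy() is a view of the tensor's memory, you need an explicit copy when you want the two to be independent. The sketch below (names are illustrative) contrasts a shared array with one made from a clone(), and also shows two cases where .numpy() cannot be called directly: a tensor that requires gradients must be detach()ed first, and a tensor on a GPU must first be moved to the CPU with .cpu(). The device line falls back to the CPU when CUDA is unavailable, so the sketch runs either way.

```python
import torch

a = torch.ones(5)
shared = a.numpy()          # shares memory with a
copied = a.clone().numpy()  # clone() first, so this array has its own memory

a.add_(1)
print(shared)  # [2. 2. 2. 2. 2.]
print(copied)  # [1. 1. 1. 1. 1.]

# A tensor that requires grad must be detached; w.numpy() would raise an error
w = torch.ones(3, requires_grad=True)
w_np = w.detach().numpy()

# A GPU tensor must be moved to host memory before conversion
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.ones(2, device=device)
t_np = t.cpu().numpy()
```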
    

    Converting NumPy Array to Torch Tensor

    import numpy as np
    import torch

    a = np.ones(5)
    b = torch.from_numpy(a)  # b shares memory with a
    np.add(a, 1, out=a)      # in-place add on the array
    print(a)
    print(b)                 # b reflects the change
    

    Out:

    [2. 2. 2. 2. 2.]
    tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
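torch.from_numpy() is not the only way to build a tensor from an array, and the alternatives differ in whether they copy. As a sketch, assuming a CPU-only setup: torch.as_tensor() also reuses the array's memory when the dtype and device already match, while torch.tensor() always copies. Note also that NumPy arrays default to float64 (hence dtype=torch.float64 in the output above), whereas PyTorch's default floating-point dtype is float32; converting with .float() produces a copy.

```python
import numpy as np
import torch

a = np.ones(3)                    # float64 by default
shared = torch.from_numpy(a)      # shares memory with a
also_shared = torch.as_tensor(a)  # also shares memory (CPU, same dtype)
copied = torch.tensor(a)          # always copies the data

np.add(a, 1, out=a)
print(shared)       # tensor([2., 2., 2.], dtype=torch.float64)
print(also_shared)  # tensor([2., 2., 2.], dtype=torch.float64)
print(copied)       # tensor([1., 1., 1.], dtype=torch.float64)

# Converting to PyTorch's default float32 makes an independent copy
as_f32 = shared.float()
print(as_f32.dtype)  # torch.float32
```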
    

    Article link: https://www.haomeiwen.com/subject/rwlxxctx.html