美文网首页
pytorch基本操作

pytorch基本操作

作者: sheng_pan_ai | 来源:发表于2019-02-24 15:34 被阅读0次

    一. pytorch基本概念

    张量(Tensors)

    x = torch.Tensor(3,5) 构建未初始化的张量
    x = torch.rand(3,5) 构建一个随机初始化的矩阵
    x.size() 或者 x.shape 获取矩阵的大小

    二. pytorch 操作

    语法1:

    x + y   x = torch.rand(2,3)
    

    语法2:

     torch.add(x,y)
    

    语法3:

    result = torch.Tensor(3,5) torch.add(x,y,out=result)
    

    语法4:

    y ._add(x)
    

    原地操作 (in-place)
    任何在原地(in-place)改变张量的操作都有一个'_'后缀。

    三. numpy桥

    把一个torch张量转换为numpy数组或者反过来都是很简单的。
    Torch张量和numpy数组将共享潜在的内存,改变其中一个也将改变另一个。a.add_(1)
    把Torch张量转换为numpy数组 :

    a = torch.ones(5)  b= a.numpy()
    

    把numpy数组转换为torch张量:

    torch.from_numpy(b)
    

    所有在CPU上的张量,除了字符张量,都支持在numpy之间转换。
    你可以使用所有的numpy索引操作: print(a[:,1])

    四. CUDA张量

    使用.cuda函数可以将张量移动到GPU上。

    if torch.cuda.is_available() :
        x = x.cuda()
    

    五. pytorch函数操作

    torch.max

    返回输入tensor中所有元素的最大值

    torch.max(input,dim)
    

    按维度dim 返回最大值

    torch.max)(a,0) 
    

    返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)

    torch.max(a,1)
    

    返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)

    torch.max()[0]
    

    只返回最大值的每个数

    troch.max()[1]
    

    只返回最大值的每个索引

    torch.eq

    target.eq(source)
    target.eq(source).sum()  统计相等的个数  输出tensor(2)
    

    torch.view

    a.view(i,j)
    

    表示将原矩阵转化为i行j列的形式 , i为-1表示不限制行数

    torch.squeeze()

    压缩矩阵

    a.squeeze(i)
    

    压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩

    torch.unsqueeze()

     unsqueeze(i)
    

    表示将第i维设置为1
    squeeze、unsqueeze操作不改变原矩阵

    torch.cat()

    cat(seq,[dim],out=None) 
    

    seq 表示要连接的两个序列.dim表示以哪个维度连接. dim=0横向连接,dim=1 纵向连接.

    a = torch.rand((10,2))
    b = torch.rand((10,2)) 
    c = torch.cat((a,b),dim=0)  横向连接  按行拼接,结构列数不变,行变多
    d = torch.cat((a,b),dim=1) 纵向连接  按列拼接,结构行数不变,需要列相同
    

    相关文章

      网友评论

          本文标题:pytorch基本操作

          本文链接:https://www.haomeiwen.com/subject/rgfbyqtx.html