美文网首页
pytorch scatter_函数

pytorch scatter_函数

作者: 全都给我Pass | 来源:发表于2020-08-19 12:20 被阅读0次

    这里表示,如果是一个3维张量,当dim设置为0(行)的时候,src参数的张量形状与self参数的张量形状,在除dim=0以外的维度,需要大小相同。及第二维大小都为j,第三维大小都为k。

    这里简单举一个二维张量的例子,当dim设置为0的时候,src参数的张量 要求与self参数的张量在列上的大小相同(dim=1)。当dim设置为1的时候,src参数的张量 要求与self参数的张量在行上的大小相同(dim=0)。

    以下为jupyter的案例及输出:

    #%%

    # 函数scatter_(dim, index, src) → Tensor  从src中取index位置的元素,index元素值表示要写入self的位置

    # 例子1:原始self是一个3行5列的随机张量,index是一个2行5列的张量,dim=0(行)

    # dim=0(行),要求self的列和 index的列大小一致,index元素值表示self的行下标,因此不能超过self的行

    index = torch.tensor([[0,1,2,0,0], [2,0,0,1,2]]) # 2行5列

    src = torch.rand(2,5)

    print(src)

    z1 = torch.zeros(3,5).scatter_(0, index, src)

    z1

    # 例子2:原始self是一个2行4列的0张量,index是一个2行1列的张量,

    # dim=1(列),要求self的行和 index的行大小一致,index元素值表示self的列下标,因此不能超过self的列

    index2 = torch.tensor([[2], [3]])

    index2.shape

    z2 = torch.zeros(2,4).scatter_(1, index2,1.23)

    z2

    # 例子3:

    # dim=0(行),要求self的列和 index的列大小一致,index元素值表示self的行下标,因此不能超过self的行

    z3 = torch.zeros(2,4).scatter_(0, torch.tensor([[0,1,0,0]]),1.23)

    z3

    相关文章

      网友评论

          本文标题:pytorch scatter_函数

          本文链接:https://www.haomeiwen.com/subject/fiprjktx.html