这里表示,如果是一个3维张量,当dim设置为0(行)的时候,src参数的张量形状与self参数的张量形状,在除dim=0以外的维度,需要大小相同。及第二维大小都为j,第三维大小都为k。
这里简单举一个二维张量的例子,当dim设置为0的时候,src参数的张量 要求与self参数的张量在列上的大小相同(dim=1)。当dim设置为1的时候,src参数的张量 要求与self参数的张量在行上的大小相同(dim=0)。
以下为jupyter的案例及输出:
#%%
# 函数scatter_(dim, index, src) → Tensor 从src中取index位置的元素,index元素值表示要写入self的位置
# 例子1:原始self是一个3行5列的随机张量,index是一个2行5列的张量,dim=0(行)
# dim=0(行),要求self的列和 index的列大小一致,index元素值表示self的行下标,因此不能超过self的行
index = torch.tensor([[0,1,2,0,0], [2,0,0,1,2]]) # 2行5列
src = torch.rand(2,5)
print(src)
z1 = torch.zeros(3,5).scatter_(0, index, src)
z1
# 例子2:原始self是一个2行4列的0张量,index是一个2行1列的张量,
# dim=1(列),要求self的行和 index的行大小一致,index元素值表示self的列下标,因此不能超过self的列
index2 = torch.tensor([[2], [3]])
index2.shape
z2 = torch.zeros(2,4).scatter_(1, index2,1.23)
z2
# 例子3:
# dim=0(行),要求self的列和 index的列大小一致,index元素值表示self的行下标,因此不能超过self的行
z3 = torch.zeros(2,4).scatter_(0, torch.tensor([[0,1,0,0]]),1.23)
z3
网友评论